{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# About this Notebook\n", "\n", "Temporal Regularized Matrix Factorization (TRMF) is an effective tool for imputing missing data within a given multivariate time series and for forecasting time series with missing values. This approach comes from the following literature:\n", "\n", "> Hsiang-Fu Yu, Nikhil Rao, Inderjit S. Dhillon, 2016. [**Temporal regularized matrix factorization for high-dimensional time series prediction**](http://www.cs.utexas.edu/~rofuyu/papers/tr-mf-nips.pdf). 30th Conference on Neural Information Processing Systems (*NIPS 2016*), Barcelona, Spain.\n", "\n", "**Acknowledgement**: We would like to thank\n", "\n", "- Antony Masso Lussier (HEC Montreal)\n", "\n", "for providing helpful suggestions and discussions. Thank you!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Quick Run\n", "\n", "This notebook is publicly available as part of our data imputation project. Please check out [**transdim**](https://github.com/xinychen/transdim).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Organization: Matrix Structure\n", "\n", "In this post, we consider a dataset of $m$ discrete time series $\\boldsymbol{y}_{i}\\in\\mathbb{R}^{f},i\\in\\left\\{1,2,...,m\\right\\}$. The time series may have missing elements. We express the spatio-temporal dataset as a matrix $Y\\in\\mathbb{R}^{m\\times f}$ with $m$ rows (e.g., locations) and $f$ columns (e.g., discrete time intervals),\n", "\n", "$$Y=\\left[ \\begin{array}{cccc} y_{11} & y_{12} & \\cdots & y_{1f} \\\\ y_{21} & y_{22} & \\cdots & y_{2f} \\\\ \\vdots & \\vdots & \\ddots & \\vdots \\\\ y_{m1} & y_{m2} & \\cdots & y_{mf} \\\\ \\end{array} \\right]\\in\\mathbb{R}^{m\\times f}.$$\n"
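, "\n", "In this notebook, a zero entry of $Y$ is treated as a missing value, and the set of observed entries is denoted by $\\Omega$. As a quick illustration (not part of the original code; the array `Y_toy` and the index set `Omega` are made-up names), the snippet below builds such a matrix with NumPy and recovers the indices of the observed entries:\n", "\n", "```python\n", "import numpy as np\n", "\n", "# A toy 3-by-4 data matrix in which zeros mark missing entries.\n", "Y_toy = np.array([[64.0,  0.0, 59.5, 61.0],\n", "                  [ 0.0, 48.5, 50.0,  0.0],\n", "                  [70.0, 71.5,  0.0, 68.0]])\n", "\n", "# Indices of the observed entries, i.e., the set Omega.\n", "Omega = np.where(Y_toy != 0)\n", "print(list(zip(Omega[0], Omega[1])))\n", "```"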
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## TRMF model\n", "\n", "Temporal Regularized Matrix Factorization (TRMF) is an approach that incorporates temporal dependencies into the commonly-used matrix factorization model. The temporal dependencies among the latent temporal factors ${\\boldsymbol{x}_t}$ are described explicitly. This approach takes the form:\n", "\n", "$$\\boldsymbol{x}_{t}\\approx\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l},$$\n", "where $\\circledast$ denotes the element-wise (Hadamard) product. This autoregressive (AR) process is specified by a lag set $\\mathcal{L}=\\left\\{l_1,l_2,...,l_d\\right\\}$ (e.g., $\\mathcal{L}=\\left\\{1,2,144\\right\\}$) and weights $\\boldsymbol{\\theta}_{l}\\in\\mathbb{R}^{r},\\forall l$. We further define\n", "\n", "$$\\mathcal{R}_{AR}\\left(X\\mid \\mathcal{L},\\Theta,\\eta\\right)=\\frac{1}{2}\\sum_{t=l_d+1}^{f}\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)^\\top\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)+\\frac{\\eta}{2}\\sum_{t=1}^{f}\\boldsymbol{x}_{t}^\\top\\boldsymbol{x}_{t}.$$\n", "\n", "Thus, TRMF-AR is given by solving\n", "\n", "$$\\min_{W,X,\\Theta}\\frac{1}{2}\\underbrace{\\sum_{(i,t)\\in\\Omega}\\left(y_{it}-\\boldsymbol{w}_{i}^T\\boldsymbol{x}_{t}\\right)^2}_{\\text{sum of squared residual errors}}+\\lambda_{w}\\underbrace{\\mathcal{R}_{w}\\left(W\\right)}_{W-\\text{regularizer}}+\\lambda_{x}\\underbrace{\\mathcal{R}_{AR}\\left(X\\mid \\mathcal{L},\\Theta,\\eta\\right)}_{\\text{AR-regularizer}}+\\lambda_{\\theta}\\underbrace{\\mathcal{R}_{\\theta}\\left(\\Theta\\right)}_{\\Theta-\\text{regularizer}},$$\n", "\n", "where $\\mathcal{R}_{w}\\left(W\\right)=\\frac{1}{2}\\sum_{i=1}^{m}\\boldsymbol{w}_{i}^\\top\\boldsymbol{w}_{i}$ and $\\mathcal{R}_{\\theta}\\left(\\Theta\\right)=\\frac{1}{2}\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}^\\top\\boldsymbol{\\theta}_{l}$ are regularization terms."
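, "\n", "As a complementary illustration (not part of the original implementation), the sketch below evaluates the AR regularizer $\\mathcal{R}_{AR}$ for small random factors, using the same `np.einsum` idiom as the code later in this notebook; the helper name `ar_regularizer` and the toy sizes are our own choices.\n", "\n", "```python\n", "import numpy as np\n", "\n", "def ar_regularizer(X, theta, time_lags, eta):\n", "    # Evaluate R_AR(X | L, Theta, eta) for X of shape (f, rank).\n", "    f = X.shape[0]\n", "    max_lag = np.max(time_lags)\n", "    res = 0.0\n", "    for t in range(max_lag, f):\n", "        # AR residual: x_t minus the sum of element-wise products theta_l * x_{t-l}.\n", "        r = X[t, :] - np.einsum('ij, ij -> j', theta, X[t - time_lags, :])\n", "        res += 0.5 * r @ r\n", "    return res + 0.5 * eta * np.sum(X ** 2)\n", "\n", "X = np.random.rand(20, 3)\n", "theta = 0.1 * np.random.rand(3, 3)\n", "print(ar_regularizer(X, theta, np.array([1, 2, 5]), eta = 0.03))\n", "```"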
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define TRMF model with `Numpy`\n", "\n", "Looking at the optimization problem of the TRMF model above, we categorize the quantities within this model as **parameters** (i.e., `init_para` in the TRMF function) and **hyperparameters** (i.e., `init_hyper`).\n", "\n", "- **Parameters** include the spatial matrix $W$, the temporal matrix $X$, and the AR coefficients $\\Theta$.\n", "- **Hyperparameters** include the weights of the regularizers, i.e., $\\lambda_w$, $\\lambda_x$, $\\lambda_\\theta$, and $\\eta$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How to understand the Python code of TRMF?\n", "\n", "#### Update spatial matrix $W$\n", "\n", "We write Python code for updating the spatial matrix as follows:\n", "\n", "```python\n", "for i in range(dim1):\n", "    pos0 = np.where(sparse_mat[i, :] != 0)\n", "    Xt = X[pos0[0], :]\n", "    vec0 = Xt.T @ sparse_mat[i, pos0[0]]\n", "    mat0 = inv(Xt.T @ Xt + lambda_w * np.eye(rank))\n", "    W[i, :] = mat0 @ vec0\n", "```\n", "\n", "To better understand this code, let us see what happens in each line. Recall that the equation for updating $W$ is\n", "$$\\boldsymbol{w}_{i} \\Leftarrow\\left(\\sum_{t:(i, t) \\in \\Omega} \\boldsymbol{x}_{t} \\boldsymbol{x}_{t}^{T}+\\lambda_{w} I\\right)^{-1} \\sum_{t:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{x}_{t}$$\n", "from the optimization problem:\n", "$$\\min _{W} \\frac{1}{2} \\underbrace{\\sum_{(i, t) \\in \\Omega}\\left(y_{i t}-\\boldsymbol{w}_{i}^{T} \\boldsymbol{x}_{t}\\right)^{2}}_{\\text {sum of squared residual errors }}+\\frac{1}{2} \\lambda_{w} \\underbrace{\\sum_{i=1}^{m} \\boldsymbol{w}_{i}^{T} \\boldsymbol{w}_{i}}_{\\text{sum of squared entries}}.$$\n", "\n", "As can be seen,\n", "\n", "- `vec0 = Xt.T @ sparse_mat[i, pos0[0]]` corresponds to $$\\sum_{t:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{x}_{t}.$$\n", "\n", "- `mat0 = inv(Xt.T @ Xt + lambda_w * np.eye(rank))` corresponds to $$\\left(\\sum_{t:(i, t) \\in \\Omega} \\boldsymbol{x}_{t} \\boldsymbol{x}_{t}^{T}+\\lambda_{w} I\\right)^{-1}.$$\n", "\n", "- `W[i, :] = mat0 @ vec0` corresponds to the update:\n", "$$\\boldsymbol{w}_{i} \\Leftarrow\\left(\\sum_{t:(i, t) \\in \\Omega} \\boldsymbol{x}_{t} \\boldsymbol{x}_{t}^{T}+\\lambda_{w} I\\right)^{-1} \\sum_{t:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{x}_{t}.$$"
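, "\n", "One small implementation note (our remark, not from the original text): since `mat0` is only used to multiply `vec0`, the explicit inverse could equivalently be computed with a linear solve, which is usually preferred numerically. A minimal sketch under this assumption, with made-up toy data:\n", "\n", "```python\n", "import numpy as np\n", "\n", "rank = 10\n", "Xt = np.random.rand(50, rank)            # temporal factors at the observed time points of row i\n", "y_obs = np.random.rand(50)               # the observed entries y_{it} of row i\n", "lambda_w = 500\n", "\n", "A = Xt.T @ Xt + lambda_w * np.eye(rank)  # sum of x_t x_t^T over observed t, plus lambda_w I\n", "b = Xt.T @ y_obs                         # sum of y_{it} x_t over observed t\n", "w_i = np.linalg.solve(A, b)              # same result as inv(A) @ b\n", "```"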
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Update temporal matrix $X$\n", "\n", "We write Python code for updating the temporal matrix as follows:\n", "\n", "```python\n", "for t in range(dim2):\n", "    pos0 = np.where(sparse_mat[:, t] != 0)\n", "    Wt = W[pos0[0], :]\n", "    Mt = np.zeros((rank, rank))\n", "    Nt = np.zeros(rank)\n", "    if t < np.max(time_lags):\n", "        Pt = np.zeros((rank, rank))\n", "        Qt = np.zeros(rank)\n", "    else:\n", "        Pt = np.eye(rank)\n", "        Qt = np.einsum('ij, ij -> j', theta, X[t - time_lags, :])\n", "    if t < dim2 - np.min(time_lags):\n", "        if t >= np.max(time_lags) and t < dim2 - np.max(time_lags):\n", "            index = list(range(0, d))\n", "        else:\n", "            index = list(np.where((t + time_lags >= np.max(time_lags)) & (t + time_lags < dim2)))[0]\n", "        for k in index:\n", "            Ak = theta[k, :]\n", "            Mt += np.diag(Ak ** 2)\n", "            theta0 = theta.copy()\n", "            theta0[k, :] = 0\n", "            Nt += np.multiply(Ak, X[t + time_lags[k], :]\n", "                              - np.einsum('ij, ij -> j', theta0, X[t + time_lags[k] - time_lags, :]))\n", "    vec0 = Wt.T @ sparse_mat[pos0[0], t] + lambda_x * Nt + lambda_x * Qt\n", "    mat0 = inv(Wt.T @ Wt + lambda_x * Mt + lambda_x * Pt + lambda_x * eta * np.eye(rank))\n", "    X[t, :] = mat0 @ vec0\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This code looks rather complicated. Let us first look at the optimization problem from which the closed-form update of $X$ is derived:\n", "$$\\min_{W,X,\\Theta}\\frac{1}{2}\\underbrace{\\sum_{(i,t)\\in\\Omega}\\left(y_{it}-\\boldsymbol{w}_{i}^T\\boldsymbol{x}_{t}\\right)^2}_{\\text{sum of squared residual errors}}+\\underbrace{\\frac{1}{2}\\lambda_{x}\\sum_{t=l_d+1}^{f}\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)^\\top\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)+\\frac{1}{2}\\lambda_{x}\\eta\\sum_{t=1}^{f}\\boldsymbol{x}_{t}^\\top\\boldsymbol{x}_{t}}_{\\text{AR-term}}+\\underbrace{\\frac{1}{2}\\lambda_{\\theta}\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}^\\top\\boldsymbol{\\theta}_{l}}_{\\Theta-\\text{term}}.$$\n", "\n", "- For $t=1,...,l_d$, the update of $X$ is\n", "$$\\boldsymbol{x}_{t} \\Leftarrow\\left(\\sum_{i:(i, t) \\in \\Omega} \\boldsymbol{w}_{i} \\boldsymbol{w}_{i}^{T}+\\lambda_{x} \\eta I\\right)^{-1} \\sum_{i:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{w}_{i}.$$\n", "- For $t=l_d+1,...,f$, the update of $X$ is\n", "$$\\boldsymbol{x}_{t}\\Leftarrow\\left(\\sum_{i:(i,t)\\in\\Omega}\\boldsymbol{w}_{i}\\boldsymbol{w}_{i}^{T}+\\lambda_x I+\\lambda_x\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\text{diag}(\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\theta}_{h})+\\lambda_x\\eta I\\right)^{-1}\\left(\\sum_{i:(i,t)\\in\\Omega}y_{it}\\boldsymbol{w}_{i}+\\lambda_x\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}+\\lambda_x\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\psi}_{t+h}\\right),$$\n", "where $\\boldsymbol{\\psi}_{t+h}=\\boldsymbol{x}_{t+h}-\\sum_{l\\in\\mathcal{L},l\\neq h}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t+h-l}$ is the AR residual at time $t+h$ with the lag-$h$ term excluded.\n", "\n", "Then, as can be seen,\n", "\n", "- `Mt += np.diag(Ak ** 2)` corresponds to $$\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\text{diag}(\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\theta}_{h}).$$\n", "\n", "- `Nt += np.multiply(Ak, X[t + time_lags[k], :] - np.einsum('ij, ij -> j', theta0, X[t + time_lags[k] - time_lags, :]))` corresponds to $$\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\psi}_{t+h}.$$\n", "\n", "- `Qt = np.einsum('ij, ij -> j', theta, X[t - time_lags, :])` corresponds to $$\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}.$$\n", "\n", "- `X[t, :] = mat0 @ vec0` corresponds to the update of $X$."
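, "\n", "Since the `np.einsum('ij, ij -> j', ...)` pattern may look cryptic, here is a tiny self-contained check (our own illustration with made-up sizes) that it computes the sum of element-wise products over the lag set:\n", "\n", "```python\n", "import numpy as np\n", "\n", "rank, d = 4, 3\n", "theta = np.random.rand(d, rank)      # one coefficient vector theta_l per lag\n", "X_lagged = np.random.rand(d, rank)   # rows play the role of x_{t-l}, l in L\n", "\n", "via_einsum = np.einsum('ij, ij -> j', theta, X_lagged)\n", "via_loop = sum(theta[l, :] * X_lagged[l, :] for l in range(d))\n", "print(np.allclose(via_einsum, via_loop))  # True\n", "```"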
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Update AR coefficients $\\Theta$\n", "\n", "We write Python code for updating the AR coefficients as follows:\n", "\n", "```python\n", "for k in range(d):\n", "    theta0 = theta.copy()\n", "    theta0[k, :] = 0\n", "    mat0 = np.zeros((dim2 - np.max(time_lags), rank))\n", "    for L in range(d):\n", "        mat0 += X[np.max(time_lags) - time_lags[L] : dim2 - time_lags[L] , :] @ np.diag(theta0[L, :])\n", "    VarPi = X[np.max(time_lags) : dim2, :] - mat0\n", "    var1 = np.zeros((rank, rank))\n", "    var2 = np.zeros(rank)\n", "    for t in range(np.max(time_lags), dim2):\n", "        B = X[t - time_lags[k], :]\n", "        var1 += np.diag(np.multiply(B, B))\n", "        var2 += np.diag(B) @ VarPi[t - np.max(time_lags), :]\n", "    theta[k, :] = inv(var1 + lambda_theta * np.eye(rank) / lambda_x) @ var2\n", "```\n", "\n", "To better understand this code, let us see what happens in each line. Recall that the equation for updating $\\boldsymbol{\\theta}_{h}$ is\n", "$$\n", "\\boldsymbol{\\theta}_{h}\\Leftarrow\\left(\\sum_{t=l_d+1}^{f}\\text{diag}(\\boldsymbol{x}_{t-h}\\circledast \\boldsymbol{x}_{t-h})+\\frac{\\lambda_{\\theta}}{\\lambda_x}I\\right)^{-1}\\left(\\sum_{t=l_d+1}^{f}{\\boldsymbol{\\pi}_{t}^{h}}\\circledast \\boldsymbol{x}_{t-h}\\right),\n", "$$\n", "where $\\boldsymbol{\\pi}_{t}^{h}=\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L},l\\neq h}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}$, from the optimization problem:\n", "$$\n", "\\min_{\\Theta}\\frac{1}{2}\\lambda_{x}\\underbrace{\\sum_{t=l_d+1}^{f}\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)^\\top\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)}_{\\text{sum of squared residual errors}}+\\frac{1}{2}\\lambda_{\\theta}\\underbrace{\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}^\\top\\boldsymbol{\\theta}_{l}}_{\\text{sum of squared entries}}.\n", "$$\n", "\n", "As can be seen,\n", "\n", "- `mat0 += X[np.max(time_lags) - time_lags[L] : dim2 - time_lags[L] , :] @ np.diag(theta0[L, :])` corresponds to $$\\sum_{l\\in\\mathcal{L},l\\neq h}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}.$$\n", "\n", "- `var1 += np.diag(np.multiply(B, B))` corresponds to $$\\sum_{t=l_d+1}^{f}\\text{diag}(\\boldsymbol{x}_{t-h}\\circledast \\boldsymbol{x}_{t-h}).$$\n", "\n", "- `var2 += np.diag(B) @ VarPi[t - np.max(time_lags), :]` corresponds to $$\\sum_{t=l_d+1}^{f}{\\boldsymbol{\\pi}_{t}^{h}}\\circledast \\boldsymbol{x}_{t-h}.$$" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from numpy.linalg import inv as inv\n", "\n", "\n", "def TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter):\n", "    \"\"\"Temporal Regularized Matrix Factorization, TRMF.\"\"\"\n", "    \n", "    ## Initialize parameters\n", "    W = init_para[\"W\"]\n", "    X = init_para[\"X\"]\n", "    theta = init_para[\"theta\"]\n", "    \n", "    ## Set hyperparameters\n", "    lambda_w = init_hyper[\"lambda_w\"]\n", "    lambda_x = init_hyper[\"lambda_x\"]\n", "    lambda_theta = init_hyper[\"lambda_theta\"]\n", "    eta = init_hyper[\"eta\"]\n", "    \n", "    dim1, dim2 = sparse_mat.shape\n", "    pos_train = np.where(sparse_mat != 0)\n", "    pos_test = np.where((dense_mat != 0) & (sparse_mat == 0))\n", "    binary_mat = sparse_mat.copy()\n", "    binary_mat[pos_train] = 1\n", "    d, rank = theta.shape\n", "    \n", "    for it in range(maxiter):\n", "        ## Update spatial matrix W\n", "        for i in range(dim1):\n", "            pos0 = np.where(sparse_mat[i, :] != 0)\n", "            Xt = X[pos0[0], :]\n", "            vec0 = Xt.T @ sparse_mat[i, pos0[0]]\n", "            mat0 = inv(Xt.T @ Xt + lambda_w * np.eye(rank))\n", "            W[i, :] = mat0 @ vec0\n", "        ## Update temporal matrix X\n", "        for t in range(dim2):\n", "            pos0 = np.where(sparse_mat[:, t] != 0)\n", "            Wt = W[pos0[0], :]\n", "            Mt = np.zeros((rank, rank))\n", "            Nt = np.zeros(rank)\n", "            if t < np.max(time_lags):\n", "                Pt = np.zeros((rank, rank))\n", "                Qt = np.zeros(rank)\n", "            else:\n", "                Pt = np.eye(rank)\n", "                Qt = np.einsum('ij, ij -> j', theta, X[t - time_lags, :])\n", "            if t < dim2 - np.min(time_lags):\n", "                if t >= np.max(time_lags) and t < dim2 - np.max(time_lags):\n", "                    index = list(range(0, d))\n", "                else:\n", "                    index = list(np.where((t + time_lags >= np.max(time_lags)) & (t + time_lags < dim2)))[0]\n", "                for k in index:\n", "                    Ak = theta[k, :]\n", "                    Mt += np.diag(Ak ** 2)\n", "                    theta0 = theta.copy()\n", "                    theta0[k, :] = 0\n", "                    Nt += np.multiply(Ak, X[t + time_lags[k], :]\n", "                                      - np.einsum('ij, ij -> j', theta0, X[t + time_lags[k] - time_lags, :]))\n", "            vec0 = Wt.T @ sparse_mat[pos0[0], t] + lambda_x * Nt + lambda_x * Qt\n", "            mat0 = inv(Wt.T @ Wt + lambda_x * Mt + lambda_x * Pt + lambda_x * eta * np.eye(rank))\n", "            X[t, :] = mat0 @ vec0\n", "        ## Update AR coefficients theta\n", "        for k in range(d):\n", "            theta0 = theta.copy()\n", "            theta0[k, :] = 0\n", "            mat0 = np.zeros((dim2 - np.max(time_lags), rank))\n", "            for L in range(d):\n", "                mat0 += X[np.max(time_lags) - time_lags[L] : dim2 - time_lags[L] , :] @ np.diag(theta0[L, :])\n", "            VarPi = X[np.max(time_lags) : dim2, :] - mat0\n", "            var1 = np.zeros((rank, rank))\n", "            var2 = np.zeros(rank)\n", "            for t in range(np.max(time_lags), dim2):\n", "                B = X[t - time_lags[k], :]\n", "                var1 += np.diag(np.multiply(B, B))\n", "                var2 += np.diag(B) @ VarPi[t - np.max(time_lags), :]\n", "            theta[k, :] = inv(var1 + lambda_theta * np.eye(rank) / lambda_x) @ var2\n", "\n", "        mat_hat = W @ X.T\n", "        mape = np.sum(np.abs(dense_mat[pos_test] - mat_hat[pos_test])\n", "                      / dense_mat[pos_test]) / dense_mat[pos_test].shape[0]\n", "        rmse = np.sqrt(np.sum((dense_mat[pos_test] - mat_hat[pos_test]) ** 2)/dense_mat[pos_test].shape[0])\n", "        \n", "        if (it + 1) % 100 == 0:\n", "            print('Iter: {}'.format(it + 1))\n", "            print('Imputation MAPE: {:.6}'.format(mape))\n", "            print('Imputation RMSE: {:.6}'.format(rmse))\n", "            print()\n", "    \n", "    return mat_hat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Missing Data Imputation\n", "\n", "In the following, we apply the TRMF function defined above to the task of missing data imputation on the following spatiotemporal multivariate time series data sets (matrices):\n", "\n", "- **Guangzhou data set**: [Guangzhou urban traffic speed data set](https://doi.org/10.5281/zenodo.1205228).\n", "- **Birmingham data set**: [Birmingham parking data set](https://archive.ics.uci.edu/ml/datasets/Parking+Birmingham).\n", "- **Hangzhou data set**: [Hangzhou metro passenger flow data set](https://doi.org/10.5281/zenodo.3145403).\n", "- **Seattle data set**: [Seattle freeway traffic speed data set](https://github.com/zhiyongc/Seattle-Loop-Data).\n", "\n", "The original data sets have been adapted for our experiments and are now available in the `datasets` folder."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on Guangzhou Speed Data\n", "\n", "**Scenario setting**:\n", "\n", "- Tensor size: $214\\times 61\\times 144$ (road segment, day, time of day)\n", "- Non-random missing (NM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Guangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Non-random missing (NM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1])[:, :, np.newaxis] + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 144}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.103494\n", "Imputation RMSE: 4.35683\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.103585\n", "Imputation RMSE: 4.36487\n", "\n", "Running time: 438 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 10\n", "time_lags = np.array([1, 2, 144])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $214\\times 61\\times 144$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Guangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 80\n", "- Time lags: {1, 2, 144}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0782638\n", "Imputation RMSE: 3.27797\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0776718\n", "Imputation RMSE: 3.25724\n", "\n", "Running time: 888 seconds\n" ] } ], "source": [ "import time\n", 
"start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 80\n", "time_lags = np.array([1, 2, 144])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $214\\times 61\\times 144$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 60% missing rate\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Guangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.6 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 80\n", "- Time lags: {1, 2, 144}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0867978\n", "Imputation RMSE: 3.59856\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0847537\n", "Imputation RMSE: 3.51839\n", "\n", "Running time: 857 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 80\n", "time_lags = np.array([1, 2, 144])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on Hangzhou Flow Data\n", "\n", "**Scenario setting**:\n", "\n", "- Tensor size: $80\\times 25\\times 108$ (metro station, day, time of day)\n", "- Non-random missing (NM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Hangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Non-random missing (NM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1])[:, :, 
np.newaxis] + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 30\n", "- Time lags: {1, 2, 108}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.257291\n", "Imputation RMSE: 38.1766\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.260397\n", "Imputation RMSE: 37.864\n", "\n", "Running time: 179 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 30\n", "time_lags = np.array([1, 2, 108])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $80\\times 25\\times 108$ (metro station, day, time of day)\n", "- Random missing (RM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Hangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 30\n", "- Time lags: {1, 2, 108}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.232536\n", "Imputation RMSE: 35.868\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.232085\n", "Imputation RMSE: 36.6294\n", "\n", "Running time: 153 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 30\n", "time_lags = np.array([1, 2, 108])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", 
"print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $80\\times 25\\times 108$ (metro station, day, time of day)\n", "- Random missing (RM)\n", "- 60% missing rate\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Hangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.6 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 30\n", "- Time lags: {1, 2, 108}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.242496\n", "Imputation RMSE: 40.4021\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.244181\n", "Imputation RMSE: 41.1514\n", "\n", "Running time: 138 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 30\n", "time_lags = np.array([1, 2, 108])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on Seattle Speed Data\n", "\n", "**Scenario setting**:\n", "\n", "- Tensor size: $323\\times 28\\times 288$ (road segment, day, time of day)\n", "- Non-random missing (NM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Seattle-data-set/tensor.npz')['arr_0']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Non-random missing (NM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1])[:, :, np.newaxis] + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 288}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0919508\n", "Imputation RMSE: 5.29839\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0919374\n", "Imputation RMSE: 5.29846\n", "\n", "Running time: 391 
seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 10\n", "time_lags = np.array([1, 2, 288])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $323\\times 28\\times 288$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Seattle-data-set/tensor.npz')['arr_0']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 50\n", "- Time lags: {1, 2, 288}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0618055\n", "Imputation RMSE: 3.79776\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0617049\n", "Imputation RMSE: 3.79509\n", "\n", "Running time: 615 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 50\n", "time_lags = np.array([1, 2, 288])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $323\\times 28\\times 288$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 60% missing rate\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Seattle-data-set/tensor.npz')['arr_0']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.6 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - 
missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 50\n", "- Time lags: {1, 2, 288}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0651363\n", "Imputation RMSE: 3.93597\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0647982\n", "Imputation RMSE: 3.92264\n", "\n", "Running time: 620 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 50\n", "time_lags = np.array([1, 2, 288])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on London Movement Speed Data\n", "\n", "The London movement speed data set is a city-wide hourly traffic speed data set collected in London.\n", "\n", "- Collected from 200,000+ road segments.\n", "- 720 time points in April 2019.\n", "- 73% missing values in the original data.\n", "\n", "| Observation rate | $>90\\%$ | $>80\\%$ | $>70\\%$ | $>60\\%$ | $>50\\%$ |\n", "|:------------------|--------:|--------:|--------:|--------:|--------:|\n", "|**Number of roads**| 17,666 | 27,148 | 35,912 | 44,352 | 52,727 |\n", "\n", "\n", "If you want to test on the full dataset, you could consider the following setting for masking observations as missing values.\n", "\n", "```python\n", "import numpy as np\n", "np.random.seed(1000)\n", "mask_rate = 0.20\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "pos_obs = np.where(dense_mat != 0)\n", "num = len(pos_obs[0])\n", "sample_ind = np.random.choice(num, size = int(mask_rate * num), replace = False)\n", "sparse_mat = dense_mat.copy()\n", "sparse_mat[pos_obs[0][sample_ind], pos_obs[1][sample_ind]] = 0\n", "```\n", "\n", "Notably, you could also consider evaluating the model on a subset of the data with the following setting."
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(1000)\n", "\n", "missing_rate = 0.4\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "binary_mat = dense_mat.copy()\n", "binary_mat[binary_mat != 0] = 1\n", "pos = np.where(np.sum(binary_mat, axis = 1) > 0.7 * binary_mat.shape[1])\n", "dense_mat = dense_mat[pos[0], :]\n", "\n", "## Non-random missing (NM)\n", "binary_mat = np.zeros(dense_mat.shape)\n", "random_mat = np.random.rand(dense_mat.shape[0], 30)\n", "for i1 in range(dense_mat.shape[0]):\n", " for i2 in range(30):\n", " binary_mat[i1, i2 * 24 : (i2 + 1) * 24] = np.round(random_mat[i1, i2] + 0.5 - missing_rate)\n", "sparse_mat = dense_mat * binary_mat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 20\n", "- Time lags: {1, 2, 24}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0958622\n", "Imputation RMSE: 2.3367\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0957049\n", "Imputation RMSE: 2.33306\n", "\n", "Running time: 960 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 20\n", "time_lags = np.array([1, 2, 24])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(1000)\n", "\n", "missing_rate = 0.4\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "binary_mat = dense_mat.copy()\n", "binary_mat[binary_mat != 0] = 1\n", "pos = np.where(np.sum(binary_mat, axis = 1) > 0.7 * binary_mat.shape[1])\n", "dense_mat = dense_mat[pos[0], :]\n", "\n", "## Random missing (RM)\n", "random_mat = np.random.rand(dense_mat.shape[0], dense_mat.shape[1])\n", "binary_mat = np.round(random_mat + 0.5 - missing_rate)\n", "sparse_mat = dense_mat * binary_mat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 20\n", "- Time lags: {1, 2, 24}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0920514\n", "Imputation RMSE: 2.23467\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0919819\n", "Imputation RMSE: 2.23439\n", "\n", "Running time: 1123 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 20\n", "time_lags = np.array([1, 2, 24])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, 
rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(1000)\n", "\n", "missing_rate = 0.6\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "binary_mat = dense_mat.copy()\n", "binary_mat[binary_mat != 0] = 1\n", "pos = np.where(np.sum(binary_mat, axis = 1) > 0.7 * binary_mat.shape[1])\n", "dense_mat = dense_mat[pos[0], :]\n", "\n", "## Random missing (RM)\n", "random_mat = np.random.rand(dense_mat.shape[0], dense_mat.shape[1])\n", "binary_mat = np.round(random_mat + 0.5 - missing_rate)\n", "sparse_mat = dense_mat * binary_mat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 20\n", "- Time lags: {1, 2, 24}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0936975\n", "Imputation RMSE: 2.27472\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0936174\n", "Imputation RMSE: 2.27413\n", "\n", "Running time: 951 seconds\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "dim1, dim2 = sparse_mat.shape\n", "rank = 20\n", "time_lags = np.array([1, 2, 24])\n", "d = time_lags.shape[0]\n", "## Initialize parameters\n", "W = 0.1 * np.random.rand(dim1, rank)\n", "X = 0.1 * np.random.rand(dim2, rank)\n", "theta = 0.1 * np.random.rand(d, rank)\n", "init_para = {\"W\": W, \"X\": X, \"theta\": theta}\n", "## Set hyparameters\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 0.03\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter)\n", "end = time.time()\n", "print('Running time: %d seconds'%(end - start))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### License\n", "\n", "
\n", "This work is released under the MIT license.\n", "
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" }, "nbTranslate": { "displayLangs": [ "*" ], "hotkey": "alt-t", "langInMainMenu": true, "sourceLang": "en", "targetLang": "fr", "useGoogleTranslate": true } }, "nbformat": 4, "nbformat_minor": 2 }