{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# About this Notebook\n", "\n", "Temporal Regularized Matrix Factorization (TRMF) is an effective tool for imputing missing data within a given multivariate time series and forecasting time series with missing values. This approach comes from the following paper:\n", "\n", "> Hsiang-Fu Yu, Nikhil Rao, Inderjit S. Dhillon, 2016. [**Temporal regularized matrix factorization for high-dimensional time series prediction**](http://www.cs.utexas.edu/~rofuyu/papers/tr-mf-nips.pdf). 30th Conference on Neural Information Processing Systems (*NIPS 2016*), Barcelona, Spain.\n", "\n", "**Acknowledgement**: We would like to thank\n", "\n", "- Antony Masso Lussier (HEC Montreal)\n", "\n", "for providing helpful suggestions and discussions. Thank you!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Quick Run\n", "\n", "This notebook is publicly available for any use as part of our data imputation project. Please see [**transdim**](https://github.com/xinychen/transdim).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Organization: Matrix Structure\n", "\n", "In this notebook, we consider a dataset of $m$ discrete time series $\\boldsymbol{y}_{i}\\in\\mathbb{R}^{f},i\\in\\left\\{1,2,...,m\\right\\}$. The time series may have missing elements. 
We express the spatio-temporal dataset as a matrix $Y\\in\\mathbb{R}^{m\\times f}$ with $m$ rows (e.g., locations) and $f$ columns (e.g., discrete time intervals),\n", "\n", "$$Y=\\left[ \\begin{array}{cccc} y_{11} & y_{12} & \\cdots & y_{1f} \\\\ y_{21} & y_{22} & \\cdots & y_{2f} \\\\ \\vdots & \\vdots & \\ddots & \\vdots \\\\ y_{m1} & y_{m2} & \\cdots & y_{mf} \\\\ \\end{array} \\right]\\in\\mathbb{R}^{m\\times f}.$$\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## TRMF model\n", "\n", "Temporal Regularized Matrix Factorization (TRMF) is an approach that incorporates temporal dependencies into the commonly used matrix factorization model. The temporal dependencies among the temporal factors ${\\boldsymbol{x}_t}$ are described explicitly. This approach takes the form:\n", "\n", "$$\\boldsymbol{x}_{t}\\approx\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l},$$\n", "where $\\circledast$ denotes the element-wise product, and this autoregressive (AR) model is specified by a lag set $\\mathcal{L}=\\left\\{l_1,l_2,...,l_d\\right\\}$ (e.g., $\\mathcal{L}=\\left\\{1,2,144\\right\\}$) and weights $\\boldsymbol{\\theta}_{l}\\in\\mathbb{R}^{r},\\forall l$. We further define\n", "\n", "$$\\mathcal{R}_{AR}\\left(X\\mid \\mathcal{L},\\Theta,\\eta\\right)=\\frac{1}{2}\\sum_{t=l_d+1}^{f}\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)^\\top\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)+\\frac{\\eta}{2}\\sum_{t=1}^{f}\\boldsymbol{x}_{t}^\\top\\boldsymbol{x}_{t}.$$\n", "\n", "Thus, TRMF-AR is given by solving\n", "\n", "$$\\min_{W,X,\\Theta}\\frac{1}{2}\\underbrace{\\sum_{(i,t)\\in\\Omega}\\left(y_{it}-\\boldsymbol{w}_{i}^T\\boldsymbol{x}_{t}\\right)^2}_{\\text{sum of squared residual errors}}+\\lambda_{w}\\underbrace{\\mathcal{R}_{w}\\left(W\\right)}_{W-\\text{regularizer}}+\\lambda_{x}\\underbrace{\\mathcal{R}_{AR}\\left(X\\mid 
\\mathcal{L},\\Theta,\\eta\\right)}_{\\text{AR-regularizer}}+\\lambda_{\\theta}\\underbrace{\\mathcal{R}_{\\theta}\\left(\\Theta\\right)}_{\\Theta-\\text{regularizer}}$$\n", "\n", "where $\\mathcal{R}_{w}\\left(W\\right)=\\frac{1}{2}\\sum_{i=1}^{m}\\boldsymbol{w}_{i}^\\top\\boldsymbol{w}_{i}$ and $\\mathcal{R}_{\\theta}\\left(\\Theta\\right)=\\frac{1}{2}\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}^\\top\\boldsymbol{\\theta}_{l}$ are regularization terms." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define TRMF model with `Numpy`\n", "\n", "Looking at the optimization problem of the TRMF model above, we split the quantities in this model into **parameters** (i.e., `init_para` in the TRMF function) and **hyperparameters** (i.e., `init_hyper`).\n", "\n", "- **Parameters** include the spatial matrix $W$, the temporal matrix $X$, and the AR coefficients $\\Theta$.\n", "- **Hyperparameters** include the weights on the regularizers, i.e., $\\lambda_w$, $\\lambda_x$, $\\lambda_\\theta$, and $\\eta$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How to understand Python code of TRMF?\n", "\n", "#### Update spatial matrix $W$\n", "\n", "The Python code for updating the spatial matrix is as follows:\n", "\n", "```python\n", "for i in range(dim1):\n", "    pos0 = np.where(sparse_mat[i, :] != 0)\n", "    Xt = X[pos0[0], :]\n", "    vec0 = Xt.T @ sparse_mat[i, pos0[0]]\n", "    mat0 = inv(Xt.T @ Xt + lambda_w * np.eye(rank))\n", "    W[i, :] = mat0 @ vec0\n", "```\n", "\n", "To better understand this code, let us see what happens in each line. 
Recall that the equation for updating $W$ is\n", "$$\\boldsymbol{w}_{i} \\Leftarrow\\left(\\sum_{t:(i, t) \\in \\Omega} \\boldsymbol{x}_{t} \\boldsymbol{x}_{t}^{T}+\\lambda_{w} I\\right)^{-1} \\sum_{t:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{x}_{t}$$\n", "from the optimization problem:\n", "$$\\min _{W} \\frac{1}{2} \\underbrace{\\sum_{(i, t) \\in \\Omega}\\left(y_{i t}-\\boldsymbol{w}_{i}^{T} \\boldsymbol{x}_{t}\\right)^{2}}_{\\text {sum of squared residual errors }}+\\frac{1}{2} \\lambda_{w} \\underbrace{\\sum_{i=1}^{m} \\boldsymbol{w}_{i}^{T} \\boldsymbol{w}_{i}}_{\\text{sum of squared entries}}.$$\n", "\n", "As can be seen,\n", "\n", "- `pos0 = np.where(sparse_mat[i, :] != 0)` selects the observed entries of row $i$ (zeros in `sparse_mat` are treated as missing), i.e., the set $\\{t:(i, t) \\in \\Omega\\}$.\n", "\n", "- `vec0 = Xt.T @ sparse_mat[i, pos0[0]]` corresponds to $$\\sum_{t:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{x}_{t}.$$\n", "\n", "- `mat0 = inv(Xt.T @ Xt + lambda_w * np.eye(rank))` corresponds to $$\\left(\\sum_{t:(i, t) \\in \\Omega} \\boldsymbol{x}_{t} \\boldsymbol{x}_{t}^{T}+\\lambda_{w} I\\right)^{-1}.$$\n", "\n", "- `W[i, :] = mat0 @ vec0` corresponds to the update:\n", "$$\\boldsymbol{w}_{i} \\Leftarrow\\left(\\sum_{t:(i, t) \\in \\Omega} \\boldsymbol{x}_{t} \\boldsymbol{x}_{t}^{T}+\\lambda_{w} I\\right)^{-1} \\sum_{t:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{x}_{t}.$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Update temporal matrix $X$\n", "\n", "The Python code for updating the temporal matrix is as follows:\n", "\n", "```python\n", "for t in range(dim2):\n", "    pos0 = np.where(sparse_mat[:, t] != 0)\n", "    Wt = W[pos0[0], :]\n", "    Mt = np.zeros((rank, rank))\n", "    Nt = np.zeros(rank)\n", "    if t < np.max(time_lags):\n", "        Pt = np.zeros((rank, rank))\n", "        Qt = np.zeros(rank)\n", "    else:\n", "        Pt = np.eye(rank)\n", "        Qt = np.einsum('ij, ij -> j', theta, X[t - time_lags, :])\n", "    if t < dim2 - np.min(time_lags):\n", "        if t >= np.max(time_lags) and t < dim2 - np.max(time_lags):\n", "            index = list(range(0, d))\n", "        else:\n", "            index = list(np.where((t + time_lags >= np.max(time_lags)) & (t + 
time_lags < dim2)))[0]\n", "        for k in index:\n", "            Ak = theta[k, :]\n", "            Mt += np.diag(Ak ** 2)\n", "            theta0 = theta.copy()\n", "            theta0[k, :] = 0\n", "            Nt += np.multiply(Ak, X[t + time_lags[k], :]\n", "                - np.einsum('ij, ij -> j', theta0, X[t + time_lags[k] - time_lags, :]))\n", "    vec0 = Wt.T @ sparse_mat[pos0[0], t] + lambda_x * Nt + lambda_x * Qt\n", "    mat0 = inv(Wt.T @ Wt + lambda_x * Mt + lambda_x * Pt + lambda_x * eta * np.eye(rank))\n", "    X[t, :] = mat0 @ vec0\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This code looks complicated. Let us first look at the optimization problem that yields a closed-form update of $X$:\n", "$$\\min_{W,X,\\Theta}\\frac{1}{2}\\underbrace{\\sum_{(i,t)\\in\\Omega}\\left(y_{it}-\\boldsymbol{w}_{i}^T\\boldsymbol{x}_{t}\\right)^2}_{\\text{sum of squared residual errors}}+\\underbrace{\\frac{1}{2}\\lambda_{x}\\sum_{t=l_d+1}^{f}\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)^\\top\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)+\\frac{1}{2}\\lambda_{x}\\eta\\sum_{t=1}^{f}\\boldsymbol{x}_{t}^\\top\\boldsymbol{x}_{t}}_{\\text{AR-term}}+\\underbrace{\\frac{1}{2}\\lambda_{\\theta}\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}^\\top\\boldsymbol{\\theta}_{l}}_{\\Theta-\\text{term}}.$$\n", "\n", "- For $t=1,...,l_d$, the update of $\\boldsymbol{x}_{t}$ is\n", "$$\\boldsymbol{x}_{t} \\Leftarrow\\left(\\sum_{i:(i, t) \\in \\Omega} \\boldsymbol{w}_{i} \\boldsymbol{w}_{i}^{T}+\\lambda_{x} \\eta I\\right)^{-1} \\sum_{i:(i, t) \\in \\Omega} y_{i t} \\boldsymbol{w}_{i}.$$\n", "- For $t=l_d+1,...,f$, the update of $\\boldsymbol{x}_{t}$ is\n", "$${\\boldsymbol{x}_{t}\\Leftarrow\\left(\\sum_{i:(i,t)\\in\\Omega}\\boldsymbol{w}_{i}\\boldsymbol{w}_{i}^{T}+\\lambda_xI+\\lambda_x\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\text{diag}(\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\theta}_{h})+\\lambda_x\\eta 
I\\right)^{-1}}{\\left(\\sum_{i:(i,t)\\in\\Omega}y_{it}\\boldsymbol{w}_{i}+\\lambda_x\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}+\\lambda_x\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\psi}_{t+h}\\right)}.$$\n", "where $\\boldsymbol{\\psi}_{t+h}=\\boldsymbol{x}_{t+h}-\\sum_{l\\in\\mathcal{L},l\\neq h}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t+h-l}$.\n", "\n", "Then, as can be seen,\n", "\n", "- `Mt += np.diag(Ak ** 2)` corresponds to $$\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\text{diag}(\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\theta}_{h}).$$\n", "\n", "- `Nt += np.multiply(Ak, X[t + time_lags[k], :] - np.einsum('ij, ij -> j', theta0, X[t + time_lags[k] - time_lags, :]))` corresponds to $$\\sum_{h\\in\\mathcal{L},t+h \\leq f}\\boldsymbol{\\theta}_{h}\\circledast\\boldsymbol{\\psi}_{t+h}.$$\n", "\n", "- `Qt = np.einsum('ij, ij -> j', theta, X[t - time_lags, :])` corresponds to $$\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}.$$\n", "\n", "- `X[t, :] = mat0 @ vec0` corresponds to the update of $\\boldsymbol{x}_{t}$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Update AR coefficients $\\Theta$\n", "\n", "The Python code for updating the AR coefficients is as follows:\n", "\n", "```python\n", "for k in range(d):\n", "    theta0 = theta.copy()\n", "    theta0[k, :] = 0\n", "    mat0 = np.zeros((dim2 - np.max(time_lags), rank))\n", "    for L in range(d):\n", "        mat0 += X[np.max(time_lags) - time_lags[L] : dim2 - time_lags[L] , :] @ np.diag(theta0[L, :])\n", "    VarPi = X[np.max(time_lags) : dim2, :] - mat0\n", "    var1 = np.zeros((rank, rank))\n", "    var2 = np.zeros(rank)\n", "    for t in range(np.max(time_lags), dim2):\n", "        B = X[t - time_lags[k], :]\n", "        var1 += np.diag(np.multiply(B, B))\n", "        var2 += np.diag(B) @ VarPi[t - np.max(time_lags), :]\n", "    theta[k, :] = inv(var1 + lambda_theta * np.eye(rank) / lambda_x) @ var2\n", "```\n", "\n", "To better understand this code, let us see what happens in each line. 
Recall that the equation for updating $\\boldsymbol{\\theta}_{h}$ is\n", "$$\n", "\\color{red} {\\boldsymbol{\\theta}_{h}\\Leftarrow\\left(\\sum_{t=l_d+1}^{f}\\text{diag}(\\boldsymbol{x}_{t-h}\\circledast \\boldsymbol{x}_{t-h})+\\frac{\\lambda_{\\theta}}{\\lambda_x}I\\right)^{-1}\\left(\\sum_{t=l_d+1}^{f}{\\boldsymbol{\\pi}_{t}^{h}}\\circledast \\boldsymbol{x}_{t-h}\\right)}\n", "$$\n", "where $\\boldsymbol{\\pi}_{t}^{h}=\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L},l\\neq h}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}$, which follows from the optimization problem:\n", "$$\n", "\\min_{\\Theta}\\frac{1}{2}\\lambda_{x}\\underbrace{\\sum_{t=l_d+1}^{f}\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)^\\top\\left(\\boldsymbol{x}_{t}-\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}\\right)}_{\\text{sum of squared residual errors}}+\\frac{1}{2}\\lambda_{\\theta}\\underbrace{\\sum_{l\\in\\mathcal{L}}\\boldsymbol{\\theta}_{l}^\\top\\boldsymbol{\\theta}_{l}}_{\\text{sum of squared entries}}\n", "$$\n", "\n", "As can be seen,\n", "\n", "- `mat0 += X[np.max(time_lags) - time_lags[L] : dim2 - time_lags[L] , :] @ np.diag(theta0[L, :])` corresponds to $$\\sum_{l\\in\\mathcal{L},l\\neq h}\\boldsymbol{\\theta}_{l}\\circledast\\boldsymbol{x}_{t-l}.$$\n", "\n", "- `var1 += np.diag(np.multiply(B, B))` corresponds to $$\\sum_{t=l_d+1}^{f}\\text{diag}(\\boldsymbol{x}_{t-h}\\circledast \\boldsymbol{x}_{t-h}).$$\n", "\n", "- `var2 += np.diag(B) @ VarPi[t - np.max(time_lags), :]` corresponds to $$\\sum_{t=l_d+1}^{f}{\\boldsymbol{\\pi}_{t}^{h}}\\circledast \\boldsymbol{x}_{t-h}.$$" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def ar4cast(theta, X, time_lags, multi_step):\n", "    \"\"\"Roll the AR model forward by `multi_step` rows beyond the last row of X.\"\"\"\n", "    dim, rank = X.shape\n", "    d = time_lags.shape[0]\n", "    X_new = np.append(X, np.zeros((multi_step, rank)), axis = 0)\n", "    for t in range(multi_step):\n", "        X_new[dim + t, :] = np.einsum('kr, 
kr -> r', theta, X_new[dim + t - time_lags, :])\n", "    return X_new" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def compute_mape(var, var_hat):\n", "    \"\"\"Mean absolute percentage error (`var` must not contain zeros).\"\"\"\n", "    return np.sum(np.abs(var - var_hat) / var) / var.shape[0]\n", "\n", "def compute_rmse(var, var_hat):\n", "    \"\"\"Root mean squared error.\"\"\"\n", "    return np.sqrt(np.sum((var - var_hat) ** 2) / var.shape[0])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from numpy.linalg import inv as inv\n", "\n", "\n", "def TRMF(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter):\n", "    \"\"\"Temporal Regularized Matrix Factorization, TRMF.\"\"\"\n", "    \n", "    ## Initialize parameters\n", "    W = init_para[\"W\"]\n", "    X = init_para[\"X\"]\n", "    theta = init_para[\"theta\"]\n", "    \n", "    ## Set hyperparameters\n", "    lambda_w = init_hyper[\"lambda_w\"]\n", "    lambda_x = init_hyper[\"lambda_x\"]\n", "    lambda_theta = init_hyper[\"lambda_theta\"]\n", "    eta = init_hyper[\"eta\"]\n", "    \n", "    dim1, dim2 = sparse_mat.shape\n", "    pos_train = np.where(sparse_mat != 0)\n", "    pos_test = np.where((dense_mat != 0) & (sparse_mat == 0))\n", "    binary_mat = sparse_mat.copy()\n", "    binary_mat[pos_train] = 1\n", "    d, rank = theta.shape\n", "    \n", "    for it in range(maxiter):\n", "        ## Update spatial matrix W\n", "        for i in range(dim1):\n", "            pos0 = np.where(sparse_mat[i, :] != 0)\n", "            Xt = X[pos0[0], :]\n", "            vec0 = Xt.T @ sparse_mat[i, pos0[0]]\n", "            mat0 = inv(Xt.T @ Xt + lambda_w * np.eye(rank))\n", "            W[i, :] = mat0 @ vec0\n", "        ## Update temporal matrix X\n", "        for t in range(dim2):\n", "            pos0 = np.where(sparse_mat[:, t] != 0)\n", "            Wt = W[pos0[0], :]\n", "            Mt = np.zeros((rank, rank))\n", "            Nt = np.zeros(rank)\n", "            if t < np.max(time_lags):\n", "                Pt = np.zeros((rank, rank))\n", "                Qt = np.zeros(rank)\n", "            else:\n", "                Pt = np.eye(rank)\n", "                Qt = np.einsum('ij, ij -> j', theta, X[t - time_lags, :])\n", "            if t < dim2 - np.min(time_lags):\n", "                if t >= 
np.max(time_lags) and t < dim2 - np.max(time_lags):\n", "                    index = list(range(0, d))\n", "                else:\n", "                    index = list(np.where((t + time_lags >= np.max(time_lags)) & (t + time_lags < dim2)))[0]\n", "                for k in index:\n", "                    Ak = theta[k, :]\n", "                    Mt += np.diag(Ak ** 2)\n", "                    theta0 = theta.copy()\n", "                    theta0[k, :] = 0\n", "                    Nt += np.multiply(Ak, X[t + time_lags[k], :]\n", "                        - np.einsum('ij, ij -> j', theta0, X[t + time_lags[k] - time_lags, :]))\n", "            vec0 = Wt.T @ sparse_mat[pos0[0], t] + lambda_x * Nt + lambda_x * Qt\n", "            mat0 = inv(Wt.T @ Wt + lambda_x * Mt + lambda_x * Pt + lambda_x * eta * np.eye(rank))\n", "            X[t, :] = mat0 @ vec0\n", "        ## Update AR coefficients theta\n", "        for k in range(d):\n", "            theta0 = theta.copy()\n", "            theta0[k, :] = 0\n", "            mat0 = np.zeros((dim2 - np.max(time_lags), rank))\n", "            for L in range(d):\n", "                mat0 += X[np.max(time_lags) - time_lags[L] : dim2 - time_lags[L] , :] @ np.diag(theta0[L, :])\n", "            VarPi = X[np.max(time_lags) : dim2, :] - mat0\n", "            var1 = np.zeros((rank, rank))\n", "            var2 = np.zeros(rank)\n", "            for t in range(np.max(time_lags), dim2):\n", "                B = X[t - time_lags[k], :]\n", "                var1 += np.diag(np.multiply(B, B))\n", "                var2 += np.diag(B) @ VarPi[t - np.max(time_lags), :]\n", "            theta[k, :] = inv(var1 + lambda_theta * np.eye(rank) / lambda_x) @ var2\n", "\n", "        ## NOTE: `multi_step` is read from the notebook's global scope\n", "        X_new = ar4cast(theta, X, time_lags, multi_step)\n", "        mat_new = W @ X_new[- multi_step :, :].T\n", "        mat_hat = W @ X.T\n", "        mape = np.sum(np.abs(dense_mat[pos_test] - mat_hat[pos_test]) \n", "                      / dense_mat[pos_test]) / dense_mat[pos_test].shape[0]\n", "        rmse = np.sqrt(np.sum((dense_mat[pos_test] - mat_hat[pos_test]) ** 2)/dense_mat[pos_test].shape[0])\n", "        \n", "        if (it + 1) % 100 == 0:\n", "            print('Iter: {}'.format(it + 1))\n", "            print('Imputation MAPE: {:.6}'.format(mape))\n", "            print('Imputation RMSE: {:.6}'.format(rmse))\n", "            print()\n", "    mat_hat = np.append(mat_hat, mat_new, axis = 1)\n", "    \n", "    return mat_hat, W, X_new, theta" ] }, { "cell_type": "code", 
"execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def update_x_partial(sparse_mat, W, X, theta, lambda_x, eta, time_lags, back_step):\n", "    \"\"\"Re-estimate only the last `back_step` rows of X, keeping W and theta fixed.\"\"\"\n", "    \n", "    dim2, rank = X.shape\n", "    d = time_lags.shape[0]  # number of lags, used by the `index` branch below\n", "    tmax = np.max(time_lags)\n", "    for t in range(dim2 - back_step, dim2):\n", "        pos0 = np.where(sparse_mat[:, t] != 0)\n", "        Wt = W[pos0[0], :]\n", "        Mt = np.zeros((rank, rank))\n", "        Nt = np.zeros(rank)\n", "        if t < tmax:\n", "            Pt = np.zeros((rank, rank))\n", "            Qt = np.zeros(rank)\n", "        else:\n", "            Pt = np.eye(rank)\n", "            Qt = np.einsum('ij, ij -> j', theta, X[t - time_lags, :])\n", "        if t < dim2 - np.min(time_lags):\n", "            if t >= tmax and t < dim2 - tmax:\n", "                index = list(range(0, d))\n", "            else:\n", "                index = list(np.where((t + time_lags >= tmax) & (t + time_lags < dim2)))[0]\n", "            for k in index:\n", "                Ak = theta[k, :]\n", "                Mt += np.diag(Ak ** 2)\n", "                theta0 = theta.copy()\n", "                theta0[k, :] = 0\n", "                Nt += np.multiply(Ak, X[t + time_lags[k], :]\n", "                    - np.einsum('ij, ij -> j', theta0, X[t + time_lags[k] - time_lags, :]))\n", "        vec0 = Wt.T @ sparse_mat[pos0[0], t] + lambda_x * Nt + lambda_x * Qt\n", "        mat0 = inv(Wt.T @ Wt + lambda_x * Mt + lambda_x * Pt + lambda_x * eta * np.eye(rank))\n", "        X[t, :] = mat0 @ vec0\n", "    return X" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def TRMF_partial(dense_mat, sparse_mat, init_para, init_hyper, time_lags, maxiter):\n", "    \"\"\"Refresh the most recent rows of X and forecast the next `multi_step` columns.\"\"\"\n", "    \n", "    ## Initialize parameters\n", "    W = init_para[\"W\"]\n", "    X = init_para[\"X\"]\n", "    theta = init_para[\"theta\"]\n", "    ## Set hyperparameters\n", "    lambda_x = init_hyper[\"lambda_x\"]\n", "    eta = init_hyper[\"eta\"] \n", "    ## NOTE: `multi_step` is read from the notebook's global scope\n", "    back_step = 10 * multi_step\n", "    for it in range(maxiter):\n", "        X = update_x_partial(sparse_mat, W, X, theta, lambda_x, eta, time_lags, back_step)\n", "    X_new = ar4cast(theta, X, time_lags, multi_step)\n", "    mat_hat = W @ X_new[- multi_step :, :].T\n", "    mat_hat[mat_hat < 0] = 0\n", "    \n", "    return mat_hat, W, X_new, theta" ] }, { 
"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from ipywidgets import IntProgress\n", "from IPython.display import display\n", "\n", "def TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter):\n", "    \"\"\"Rolling forecast of the last `pred_step` columns, `multi_step` columns at a time.\"\"\"\n", "    dim1, T = dense_mat.shape\n", "    d = time_lags.shape[0]\n", "    start_time = T - pred_step\n", "    max_count = int(np.ceil(pred_step / multi_step))\n", "    mat_hat = np.zeros((dim1, max_count * multi_step))\n", "    f = IntProgress(min = 0, max = max_count) # instantiate the bar\n", "    display(f) # display the bar\n", "    for t in range(max_count):\n", "        if t == 0:\n", "            init_para = {\"W\": 0.1 * np.random.randn(dim1, rank), \n", "                         \"X\": 0.1 * np.random.randn(start_time, rank),\n", "                         \"theta\": 0.1 * np.random.randn(d, rank)}\n", "            mat, W, X_new, theta = TRMF(dense_mat[:, 0 : start_time], sparse_mat[:, 0 : start_time], \n", "                init_para, init_hyper, time_lags, maxiter)\n", "        else:\n", "            init_para = {\"W\": W, \"X\": X_new, \"theta\": theta}\n", "            mat, W, X_new, theta = TRMF_partial(dense_mat[:, 0 : start_time + t * multi_step], \n", "                sparse_mat[:, 0 : start_time + t * multi_step], \n", "                init_para, init_hyper, time_lags, maxiter)\n", "        mat_hat[:, t * multi_step : (t + 1) * multi_step] = mat[:, - multi_step :]\n", "        f.value = t\n", "    small_dense_mat = dense_mat[:, start_time : T]\n", "    pos = np.where(small_dense_mat != 0)\n", "    print('Prediction MAPE: {:.6}'.format(compute_mape(small_dense_mat[pos], mat_hat[pos])))\n", "    print('Prediction RMSE: {:.6}'.format(compute_rmse(small_dense_mat[pos], mat_hat[pos])))\n", "    print()\n", "    return mat_hat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on Guangzhou Speed Data\n", "\n", "**Scenario setting**:\n", "\n", "- Tensor size: $214\\times 61\\times 144$ (road segment, day, time of day)\n", "- Non-random missing (NM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], 
"source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Guangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Non-random missing (NM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1])[:, :, np.newaxis] + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 144}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2bc5784af9224a6ab72882c3ff7b92d8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.115188\n", "Imputation RMSE: 4.6092\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.115174\n", "Imputation RMSE: 4.60879\n", "\n", "Prediction MAPE: 0.127976\n", "Prediction RMSE: 4.87849\n", "\n", "Running time: 1011 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d44dcf12e08a496aac04cb83a2dd6c25", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=252)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.115176\n", "Imputation RMSE: 4.60879\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.115166\n", "Imputation RMSE: 4.60855\n", 
"\n", "Prediction MAPE: 0.128447\n", "Prediction RMSE: 4.88498\n", "\n", "Running time: 1033 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "37013a86d03a476cab31cf2bae5b7de0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=168)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.115173\n", "Imputation RMSE: 4.60868\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.115167\n", "Imputation RMSE: 4.60854\n", "\n", "Prediction MAPE: 0.133175\n", "Prediction RMSE: 5.11059\n", "\n", "Running time: 991 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 144\n", "time_lags = np.array([1, 2, 3, 144, 145, 146, 7 * 144, 7 * 144 + 1, 7 * 144 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $214\\times 61\\times 144$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Guangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", 
"missing_rate = 0.4 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 144}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "de4ca3cc9ea34390aa7f60f667ba8123", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.113995\n", "Imputation RMSE: 4.56227\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.113972\n", "Imputation RMSE: 4.5614\n", "\n", "Prediction MAPE: 0.127169\n", "Prediction RMSE: 4.83963\n", "\n", "Running time: 1064 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "705521b5f08e4cbd976fa98639b2b4ec", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=252)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.113994\n", "Imputation RMSE: 4.56227\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.11398\n", "Imputation RMSE: 4.56178\n", "\n", "Prediction MAPE: 0.12824\n", "Prediction RMSE: 4.85103\n", "\n", "Running time: 1061 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"8c7469bedefa418f87e7d6c8046ca4a1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=168)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.113997\n", "Imputation RMSE: 4.56235\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.113969\n", "Imputation RMSE: 4.56124\n", "\n", "Prediction MAPE: 0.134086\n", "Prediction RMSE: 5.14488\n", "\n", "Running time: 982 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 144\n", "time_lags = np.array([1, 2, 3, 144, 145, 146, 7 * 144, 7 * 144 + 1, 7 * 144 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $214\\times 61\\times 144$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 60% missing rate\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Guangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.6 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = 
sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 144}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e7f8d54aa7fb4a38a9fb7fbb2d25c20c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.125244\n", "Imputation RMSE: 4.92197\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.125242\n", "Imputation RMSE: 4.92202\n", "\n", "Prediction MAPE: 0.138505\n", "Prediction RMSE: 5.18799\n", "\n", "Running time: 978 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b8e5cf4d91a74b21b8389f7f1e885559", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=252)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.125256\n", "Imputation RMSE: 4.92244\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.125242\n", "Imputation RMSE: 4.92198\n", "\n", "Prediction MAPE: 0.139365\n", "Prediction RMSE: 5.18961\n", "\n", "Running time: 979 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ae2875d32d904f8199180821d3e5a35f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=168)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", 
"Imputation MAPE: 0.125255\n", "Imputation RMSE: 4.92228\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.125252\n", "Imputation RMSE: 4.92228\n", "\n", "Prediction MAPE: 0.142974\n", "Prediction RMSE: 5.32899\n", "\n", "Running time: 981 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 144\n", "time_lags = np.array([1, 2, 3, 144, 145, 146, 7 * 144, 7 * 144 + 1, 7 * 144 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import warnings\n", "warnings.simplefilter('ignore')\n", "\n", "tensor = scipy.io.loadmat('../datasets/Guangzhou-data-set/tensor.mat')['tensor']\n", "dense_mat = tensor.reshape([tensor.shape[0], tensor.shape[1] * tensor.shape[2]])\n", "sparse_mat = dense_mat.copy()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e088890ffc7745409adb9e282e4a4c56", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", 
"Prediction MAPE: 0.115997\n", "Prediction RMSE: 4.54619\n", "\n", "Running time: 998 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4e4dfa55305b444288e372e5ee1f9097", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=252)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.119562\n", "Prediction RMSE: 4.61194\n", "\n", "Running time: 1013 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a4dd1153b79e4c05821719a804ddab90", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=168)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.122818\n", "Prediction RMSE: 4.90498\n", "\n", "Running time: 1002 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 144\n", "time_lags = np.array([1, 2, 3, 144, 145, 146, 7 * 144, 7 * 144 + 1, 7 * 144 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d 
seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on Hangzhou Flow Data\n", "\n", "**Scenario setting**:\n", "\n", "- Tensor size: $80\\times 25\\times 108$ (metro station, day, time of day)\n", "- Non-random missing (NM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Hangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Non-random missing (NM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1])[:, :, np.newaxis] + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 108, 109, 110, 756, 757, 758}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "34d7f626baed48b59a644ff839d7e511", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=378)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.251544\n", "Imputation RMSE: 57.7424\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.252459\n", "Imputation RMSE: 57.1047\n", "\n", "Prediction MAPE: 0.255831\n", "Prediction RMSE: 38.6273\n", "\n", "Running time: 367 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id":
"2eaba748c27b44b9afe10dd6ba1f5767", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=189)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.251771\n", "Imputation RMSE: 57.7999\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.252862\n", "Imputation RMSE: 57.1221\n", "\n", "Prediction MAPE: 0.28039\n", "Prediction RMSE: 40.4237\n", "\n", "Running time: 366 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ad9b91f32efc48dc84a0db8f54b7a8ba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=126)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.251875\n", "Imputation RMSE: 57.5126\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.25278\n", "Imputation RMSE: 57.1189\n", "\n", "Prediction MAPE: 0.276867\n", "Prediction RMSE: 42.6404\n", "\n", "Running time: 365 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 108\n", "time_lags = np.array([1, 2, 3, 108, 109, 110, 7 * 108, 7 * 108 + 1, 7 * 108 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $80\\times 25\\times 108$ (metro 
station, day, time of day)\n", "- Random missing (RM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Hangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 108, 109, 110, 756, 757, 758}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dcc910025457490991137338e3041fe4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=378)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.231354\n", "Imputation RMSE: 50.8092\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.231429\n", "Imputation RMSE: 50.6602\n", "\n", "Prediction MAPE: 0.238024\n", "Prediction RMSE: 35.7663\n", "\n", "Running time: 365 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "63db82685da84570947866ddcd3c27ca", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=189)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter:
100\n", "Imputation MAPE: 0.231151\n", "Imputation RMSE: 50.6426\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.231282\n", "Imputation RMSE: 50.5112\n", "\n", "Prediction MAPE: 0.273416\n", "Prediction RMSE: 38.3588\n", "\n", "Running time: 371 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7d3fd9cba5e94007a5c0b84e5202704b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=126)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.230971\n", "Imputation RMSE: 50.5664\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.231013\n", "Imputation RMSE: 50.4308\n", "\n", "Prediction MAPE: 0.273411\n", "Prediction RMSE: 40.3949\n", "\n", "Running time: 368 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 108\n", "time_lags = np.array([1, 2, 3, 108, 109, 110, 7 * 108, 7 * 108 + 1, 7 * 108 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $80\\times 25\\times 108$ (metro station, day, time of day)\n", "- Random missing (RM)\n", "- 60% missing rate\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", 
"np.random.seed(1000)\n", "\n", "dense_tensor = scipy.io.loadmat('../datasets/Hangzhou-data-set/tensor.mat')['tensor']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.6 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 108, 109, 110, 756, 757, 758}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d798b04e435e4b0b808d1e3112be06c3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=378)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.245918\n", "Imputation RMSE: 58.3018\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.245755\n", "Imputation RMSE: 58.3847\n", "\n", "Prediction MAPE: 0.258238\n", "Prediction RMSE: 41.9281\n", "\n", "Running time: 368 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2e35ed609fb44a82acbfed8c01917104", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=189)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.245739\n", "Imputation RMSE: 58.1084\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.24575\n", "Imputation RMSE: 58.2141\n", "\n", "Prediction MAPE: 0.25099\n", "Prediction RMSE: 44.2074\n", "\n", "Running time: 367
seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6372cbff9bdb4912bb0d3cabdd2927e6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=126)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.245843\n", "Imputation RMSE: 58.1553\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.24597\n", "Imputation RMSE: 58.2217\n", "\n", "Prediction MAPE: 0.268117\n", "Prediction RMSE: 46.4623\n", "\n", "Running time: 368 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 108\n", "time_lags = np.array([1, 2, 3, 108, 109, 110, 7 * 108, 7 * 108 + 1, 7 * 108 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import scipy.io\n", "import warnings\n", "warnings.simplefilter('ignore')\n", "\n", "tensor = scipy.io.loadmat('../datasets/Hangzhou-data-set/tensor.mat')['tensor']\n", "dense_mat = tensor.reshape([tensor.shape[0], tensor.shape[1] * tensor.shape[2]])\n", "sparse_mat = dense_mat.copy()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"b96042d1b0ef4f78a20a25aa90d09200", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=378)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.224683\n", "Prediction RMSE: 30.5755\n", "\n", "Running time: 372 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2e3e8ce11f8748caa962b68b2ec111ff", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=189)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.241361\n", "Prediction RMSE: 32.6289\n", "\n", "Running time: 370 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "63099e610c304a2d8f7dbd56183f2a72", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=126)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.25415\n", "Prediction RMSE: 33.9146\n", "\n", "Running time: 370 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 108\n", "time_lags = np.array([1, 2, 3, 108, 109, 110, 7 * 108, 7 * 108 + 1, 7 * 108 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, 
\"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on Seattle Speed Data\n", "\n", "**Scenario setting**:\n", "\n", "- Tensor size: $323\\times 28\\times 288$ (road segment, day, time of day)\n", "- Non-random missing (NM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = np.load('../datasets/Seattle-data-set/tensor.npz')['arr_0']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Non-random missing (NM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1])[:, :, np.newaxis] + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 288, 289, 290, 2016, 2017, 2018}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dea8c7312a2449c7984bfbf93fc10ed2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=1008)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text":
[ "Iter: 100\n", "Imputation MAPE: 0.101807\n", "Imputation RMSE: 5.56864\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.10178\n", "Imputation RMSE: 5.5675\n", "\n", "Prediction MAPE: 0.116346\n", "Prediction RMSE: 5.98404\n", "\n", "Running time: 1119 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2fbb2273f5f047d3b9ebd7c73ccbd78e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.101811\n", "Imputation RMSE: 5.56841\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.101781\n", "Imputation RMSE: 5.56744\n", "\n", "Prediction MAPE: 0.116778\n", "Prediction RMSE: 6.02036\n", "\n", "Running time: 1123 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "93c5462f2b0644ef89f102af3463bf4d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=336)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.101812\n", "Imputation RMSE: 5.56882\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.10178\n", "Imputation RMSE: 5.56738\n", "\n", "Prediction MAPE: 0.119966\n", "Prediction RMSE: 6.17256\n", "\n", "Running time: 1141 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 288\n", "time_lags = np.array([1, 2, 3, 288, 289, 290, 7 * 288, 7 * 288 + 1, 7 * 288 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon 
(delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $323\\times 28\\times 288$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 40% missing rate\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = np.load('../datasets/Seattle-data-set/tensor.npz')['arr_0']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.4 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 288, 289, 290, 2016, 2017, 2018}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "006f8237eaa24d1fb0592b448975614d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=1008)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0962245\n", "Imputation RMSE: 5.28646\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.096197\n", "Imputation RMSE: 5.28516\n", "\n", "Prediction MAPE: 0.113667\n", "Prediction RMSE: 5.85433\n", "\n", "Running time:
1113 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b487a92a244f4858b9f59a1d4f14335e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.096216\n", "Imputation RMSE: 5.28609\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0961997\n", "Imputation RMSE: 5.28534\n", "\n", "Prediction MAPE: 0.115298\n", "Prediction RMSE: 5.95917\n", "\n", "Running time: 1125 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "75183f15618a4244a5c02e2db7b6186e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=336)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0962624\n", "Imputation RMSE: 5.28802\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0962058\n", "Imputation RMSE: 5.28551\n", "\n", "Prediction MAPE: 0.119754\n", "Prediction RMSE: 6.14716\n", "\n", "Running time: 1135 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 288\n", "time_lags = np.array([1, 2, 3, 288, 289, 290, 7 * 288, 7 * 288 + 1, 7 * 288 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "**Scenario setting**:\n", "\n", "- Tensor size: $323\\times 28\\times 288$ (road segment, day, time of day)\n", "- Random missing (RM)\n", "- 60% missing rate\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "import time\n", "import scipy.io\n", "import numpy as np\n", "np.random.seed(1000)\n", "\n", "dense_tensor = np.load('../datasets/Seattle-data-set/tensor.npz')['arr_0']\n", "dim = dense_tensor.shape\n", "missing_rate = 0.6 # Random missing (RM)\n", "sparse_tensor = dense_tensor * np.round(np.random.rand(dim[0], dim[1], dim[2]) + 0.5 - missing_rate)\n", "dense_mat = dense_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "sparse_mat = sparse_tensor.reshape([dim[0], dim[1] * dim[2]])\n", "del dense_tensor, sparse_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 288, 289, 290, 2016, 2017, 2018}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b230161ef491492eb432246fb53180ae", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=1008)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.106593\n", "Imputation RMSE: 5.66466\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.106592\n", "Imputation RMSE: 5.66458\n", "\n", "Prediction MAPE: 0.123772\n", "Prediction RMSE: 6.19303\n", "\n", "Running time: 1090 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5141342081e247c3bde65a019264edbd", "version_major": 2, "version_minor": 0 }, "text/plain": [
"IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.106616\n", "Imputation RMSE: 5.66563\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.106596\n", "Imputation RMSE: 5.66472\n", "\n", "Prediction MAPE: 0.126678\n", "Prediction RMSE: 6.36595\n", "\n", "Running time: 1103 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5e1653f2d22b426080a8b9b850741bfd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=336)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.106616\n", "Imputation RMSE: 5.66558\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.106592\n", "Imputation RMSE: 5.66445\n", "\n", "Prediction MAPE: 0.128849\n", "Prediction RMSE: 6.41523\n", "\n", "Running time: 1105 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 288\n", "time_lags = np.array([1, 2, 3, 288, 289, 290, 7 * 288, 7 * 288 + 1, 7 * 288 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import warnings\n", "warnings.simplefilter('ignore')\n", "\n", "dense_mat = 
pd.read_csv('../datasets/Seattle-data-set/mat.csv', index_col = 0)\n", "dense_mat = dense_mat.values\n", "sparse_mat = dense_mat.copy()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a1e588239ddc404bbd8dcfd2eb131401", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=1008)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.104488\n", "Prediction RMSE: 5.58347\n", "\n", "Running time: 1157 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1f6fc9d8b6df4bba9c7c712fad6ca6ab", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=504)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.106473\n", "Prediction RMSE: 5.70034\n", "\n", "Running time: 1162 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "553a0f06f62f4dbaa8ab856402d16c7d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=336)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 
0.111521\n", "Prediction RMSE: 5.91872\n", "\n", "Running time: 1181 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 288\n", "time_lags = np.array([1, 2, 3, 288, 289, 290, 7 * 288, 7 * 288 + 1, 7 * 288 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on London Movement Speed Data\n", "\n", "The London movement speed data set is a city-wide hourly traffic speed dataset collected in London.\n", "\n", "- Collected from 200,000+ road segments.\n", "- 720 time points in April 2019.\n", "- 73% missing values in the original data.\n", "\n", "| Observation rate | $>90\\%$ | $>80\\%$ | $>70\\%$ | $>60\\%$ | $>50\\%$ |\n", "|:------------------|--------:|--------:|--------:|--------:|--------:|\n", "|**Number of roads**| 17,666 | 27,148 | 35,912 | 44,352 | 52,727 |\n", "\n", "\n", "If you want to test on the full dataset, you could use the following setting to mask observations as missing values.
\n", "\n", "```python\n", "import numpy as np\n", "np.random.seed(1000)\n", "mask_rate = 0.20\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "pos_obs = np.where(dense_mat != 0)\n", "num = len(pos_obs[0])\n", "sample_ind = np.random.choice(num, size = int(mask_rate * num), replace = False)\n", "sparse_mat = dense_mat.copy()\n", "sparse_mat[pos_obs[0][sample_ind], pos_obs[1][sample_ind]] = 0\n", "```\n", "\n", "Alternatively, you could evaluate the model on a subset of the data with the following setting." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(1000)\n", "\n", "missing_rate = 0.4\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "binary_mat = dense_mat.copy()\n", "binary_mat[binary_mat != 0] = 1\n", "pos = np.where(np.sum(binary_mat, axis = 1) > 0.7 * binary_mat.shape[1])\n", "dense_mat = dense_mat[pos[0], :]\n", "\n", "## Non-random missing (NM)\n", "binary_mat = np.zeros(dense_mat.shape)\n", "random_mat = np.random.rand(dense_mat.shape[0], 30)\n", "for i1 in range(dense_mat.shape[0]):\n", " for i2 in range(30):\n", " binary_mat[i1, i2 * 24 : (i2 + 1) * 24] = np.round(random_mat[i1, i2] + 0.5 - missing_rate)\n", "sparse_mat = dense_mat * binary_mat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 24, 25, 26, 168, 169, 170}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f12a687720e74caaa8e929b1c03ee002", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=84)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream",
"text": [ "Iter: 100\n", "Imputation MAPE: 0.0993481\n", "Imputation RMSE: 2.40038\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0993169\n", "Imputation RMSE: 2.39941\n", "\n", "Prediction MAPE: 0.116247\n", "Prediction RMSE: 2.72127\n", "\n", "Running time: 1184 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ef983abf90974c7299fd0b123745ca8b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=42)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0993366\n", "Imputation RMSE: 2.40001\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0993201\n", "Imputation RMSE: 2.39946\n", "\n", "Prediction MAPE: 0.127815\n", "Prediction RMSE: 3.14071\n", "\n", "Running time: 1188 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8f77dbb413e245499616b15a2594ea68", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=28)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0993521\n", "Imputation RMSE: 2.40041\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0993202\n", "Imputation RMSE: 2.39952\n", "\n", "Prediction MAPE: 0.123015\n", "Prediction RMSE: 2.86905\n", "\n", "Running time: 1188 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 24\n", "time_lags = np.array([1, 2, 3, 24, 25, 26, 7 * 24, 7 * 24 + 1, 7 * 24 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time 
horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(1000)\n", "\n", "missing_rate = 0.4\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "binary_mat = dense_mat.copy()\n", "binary_mat[binary_mat != 0] = 1\n", "pos = np.where(np.sum(binary_mat, axis = 1) > 0.7 * binary_mat.shape[1])\n", "dense_mat = dense_mat[pos[0], :]\n", "\n", "## Random missing (RM)\n", "random_mat = np.random.rand(dense_mat.shape[0], dense_mat.shape[1])\n", "binary_mat = np.round(random_mat + 0.5 - missing_rate)\n", "sparse_mat = dense_mat * binary_mat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 24, 25, 26, 168, 169, 170}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bad673bee5b448ab810ea579abea4126", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=84)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0972497\n", "Imputation RMSE: 2.34772\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0972339\n", "Imputation RMSE: 2.34705\n", "\n", "Prediction MAPE: 0.114902\n", "Prediction RMSE: 2.70622\n", "\n", "Running time: 1168 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9e437f8ce51f4e8397a53a2f34e1c5d6", "version_major": 
2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=42)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0972483\n", "Imputation RMSE: 2.34754\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0972287\n", "Imputation RMSE: 2.34688\n", "\n", "Prediction MAPE: 0.117015\n", "Prediction RMSE: 2.76467\n", "\n", "Running time: 1207 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e1d31bd3074e41d396abf12b1251f06b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=28)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.0972588\n", "Imputation RMSE: 2.34771\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.0972342\n", "Imputation RMSE: 2.34695\n", "\n", "Prediction MAPE: 0.125986\n", "Prediction RMSE: 2.94896\n", "\n", "Running time: 1239 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 24\n", "time_lags = np.array([1, 2, 3, 24, 25, 26, 7 * 24, 7 * 24 + 1, 7 * 24 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(1000)\n", "\n", "missing_rate = 0.6\n", "\n", "dense_mat = 
np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "binary_mat = dense_mat.copy()\n", "binary_mat[binary_mat != 0] = 1\n", "pos = np.where(np.sum(binary_mat, axis = 1) > 0.7 * binary_mat.shape[1])\n", "dense_mat = dense_mat[pos[0], :]\n", "\n", "## Random missing (RM)\n", "random_mat = np.random.rand(dense_mat.shape[0], dense_mat.shape[1])\n", "binary_mat = np.round(random_mat + 0.5 - missing_rate)\n", "sparse_mat = dense_mat * binary_mat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model setting**:\n", "\n", "- Low rank: 10\n", "- Time lags: {1, 2, 3, 24, 25, 26, 168, 169, 170}\n", "- The number of iterations: 200" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9a27180e7fc74366b2be42e4441faa37", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=84)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.102163\n", "Imputation RMSE: 2.45197\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.102118\n", "Imputation RMSE: 2.45057\n", "\n", "Prediction MAPE: 0.119481\n", "Prediction RMSE: 2.7947\n", "\n", "Running time: 1061 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7b66d65e3c03449ab29b95b4503d2f6e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=42)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.102112\n", "Imputation RMSE: 2.45041\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.102092\n", "Imputation RMSE: 2.44987\n", "\n", "Prediction MAPE: 0.120207\n", "Prediction RMSE: 2.79563\n", "\n", "Running time: 1072 seconds\n", 
"\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0bda89fb94e44bd1ac62623ac658bf96", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=28)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: 0.102122\n", "Imputation RMSE: 2.45049\n", "\n", "Iter: 200\n", "Imputation MAPE: 0.1021\n", "Imputation RMSE: 2.44998\n", "\n", "Prediction MAPE: 0.126814\n", "Prediction RMSE: 2.94801\n", "\n", "Running time: 1066 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 24\n", "time_lags = np.array([1, 2, 3, 24, 25, 26, 7 * 24, 7 * 24 + 1, 7 * 24 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", "lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import warnings\n", "warnings.simplefilter('ignore')\n", "\n", "dense_mat = np.load('../datasets/London-data-set/hourly_speed_mat.npy')\n", "binary_mat = dense_mat.copy()\n", "binary_mat[binary_mat != 0] = 1\n", "pos = np.where(np.sum(binary_mat, axis = 1) > 0.7 * binary_mat.shape[1])\n", "dense_mat = dense_mat[pos[0], :]\n", "sparse_mat = dense_mat.copy()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction time horizon (delta) = 
2.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a5ed24263b8642349a2066f5f882401f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=84)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.113515\n", "Prediction RMSE: 2.67703\n", "\n", "Running time: 1452 seconds\n", "\n", "Prediction time horizon (delta) = 4.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0a68e4fd2e0147a2b4502c062902b10a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=42)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.116352\n", "Prediction RMSE: 2.79261\n", "\n", "Running time: 1420 seconds\n", "\n", "Prediction time horizon (delta) = 6.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bba3879155ab4cc99621fc72be38caff", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, max=28)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Iter: 100\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Iter: 200\n", "Imputation MAPE: nan\n", "Imputation RMSE: nan\n", "\n", "Prediction MAPE: 0.127962\n", "Prediction RMSE: 3.04552\n", "\n", "Running time: 1420 seconds\n", "\n" ] } ], "source": [ "import time\n", "start = time.time()\n", "rank = 10\n", "pred_step = 7 * 24\n", "time_lags = np.array([1, 2, 3, 24, 25, 26, 7 * 24, 7 * 24 + 1, 7 * 24 + 2])\n", "lambda_w = 500\n", "lambda_x = 500\n", 
"lambda_theta = 500\n", "eta = 1\n", "init_hyper = {\"lambda_w\": lambda_w, \"lambda_x\": lambda_x, \"lambda_theta\": lambda_theta, \"eta\": eta}\n", "maxiter = 200\n", "for multi_step in [2, 4, 6]:\n", " start = time.time()\n", " print('Prediction time horizon (delta) = {}.'.format(multi_step))\n", " mat_hat = TRMF_forecast(dense_mat, sparse_mat, init_hyper, pred_step, multi_step, rank, time_lags, maxiter)\n", " end = time.time()\n", " print('Running time: %d seconds'%(end - start))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### License\n", "\n", "
\n", "This work is released under the MIT license.\n", "
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" }, "nbTranslate": { "displayLangs": [ "*" ], "hotkey": "alt-t", "langInMainMenu": true, "sourceLang": "en", "targetLang": "fr", "useGoogleTranslate": true } }, "nbformat": 4, "nbformat_minor": 2 }