{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Missing Data Imputation\n",
    "\n",
    "Missing data imputation implies filling the missing entries in our data, which can be tabular data, images, time series, with \"reasonable\" values. Ideally, the imputed values should be a good guess of the values that are missing, but that we cannot observe.\n",
    "\n",
    "Very simple techniques for imputing missing values are:\n",
    "- `zero imputation`: replace the missing values with a $0$.\n",
    "- `mean imputation`: replace with the mean value of a given feature in the dataset.\n",
    "- `last value carried-forward`, also called `forward-filling`, replaces the missing value with the last observed one.\n",
    "\n",
    "In this notebook we see how to perform imputation of missing data using a Reservoir. Hopefully, we will manage to impute missing values more precisely than the simple baseline approaches.\n",
    "\n",
    "We start by defining some utility functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import Ridge\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.impute import SimpleImputer\n",
    "\n",
    "# Local imports\n",
    "from reservoir_computing.datasets import ClfLoader\n",
    "from reservoir_computing.reservoir import Reservoir\n",
    "\n",
    "np.random.seed(0) # for reproducibility\n",
    "\n",
    "def plot_missing_data(data):\n",
    "    \"\"\"\n",
    "    Plots the missing data mask.\n",
    "    Red = missing.\n",
    "    Green = data available.\n",
    "    \"\"\"\n",
    "\n",
    "    data = data.T\n",
    "    missing_mask = ~np.isnan(data)\n",
    "\n",
    "    _, ax = plt.subplots(figsize=(8, 3))\n",
    "    _ = ax.imshow(missing_mask, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)\n",
    "    ax.set_title(\"Missing Data Visualization\", fontsize=16)\n",
    "    ax.set_xlabel(\"Time Steps\", fontsize=14)\n",
    "    ax.set_ylabel(\"Features\", fontsize=14)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "def forward_fill_timewise(data_3d):\n",
    "    \"\"\"\n",
    "    Replace the missing values with the last observed value\n",
    "    along the time dimension.\n",
    "    The input `data_3d` is a 3-dimensional array of shape [N, T, V].\n",
    "    If the missing value is at the first time step, replace it with zero.\n",
    "    \"\"\"\n",
    "    N, T, V = data_3d.shape\n",
    "    for n in range(N):\n",
    "        for v in range(V):\n",
    "            for t in range(T):\n",
    "                if np.isnan(data_3d[n, t, v]):\n",
    "                    if t == 0:\n",
    "                        data_3d[n, t, v] = 0.0\n",
    "                    else:\n",
    "                        data_3d[n, t, v] = data_3d[n, t - 1, v]\n",
    "    return data_3d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## All features missing\n",
    "\n",
    "We will consider two settings. In the first, all the features are missing from the time series at the same time.\n",
    "\n",
    "In a second setting, that we will see later, only a subset of the features might be missing in the time series at a given time.\n",
    "\n",
    "### Generate missing values\n",
    "\n",
    "Next we load the data and add the missing values.\n",
    "\n",
    "We will use a set of multivariate time series, which we store in an array of size $[N, T, V]$, where $N$ denotes the number of samples, $T$ the length of the time series, and $V$ the number of variables. \n",
    "\n",
    "We also normalize the data with a standard scaler, so that the values are centered around $0$ with standard deviation $1$. This is important because if we pass high values to the Reservoir its activations saturate.\n",
    "\n",
    "We will add missing values using two different patterns.\n",
    "- `point`: with a probability $p_\\text{point}$ (`p_missing_point`), the value $X[n,t,v]$ is missing.\n",
    "- `block`: with a probability $p_\\text{block}$ (`p_missing_block`) and a window of length $w$ (`duration_block`), the values $X[n,t:t+w,v]$ are missing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded Japanese_Vowels dataset.\n",
      "Number of classes: 9\n",
      "Data shapes:\n",
      "  Xtr: (270, 29, 12)\n",
      "  Ytr: (270, 1)\n",
      "  Xte: (370, 29, 12)\n",
      "  Yte: (370, 1)\n",
      "Number of missing values: 116688 (0.52%).\n"
     ]
    }
   ],
   "source": [
    "# Load the data\n",
    "Xtr, _, Xte, _ = ClfLoader().get_data('Japanese_Vowels')\n",
    "X = np.concatenate([Xtr, Xte], axis=0)  \n",
    "N, T, V = X.shape\n",
    "\n",
    "# Normalize data\n",
    "scaler = StandardScaler()\n",
    "X = scaler.fit_transform(X.reshape(X.shape[0], -1)).reshape(X.shape)\n",
    "\n",
    "# Parameters for missingness\n",
    "p_missing_point = 0.2\n",
    "p_missing_block = 0.1\n",
    "duration_block = 5\n",
    "\n",
    "# Add missing values\n",
    "X_missing = X.copy()\n",
    "for i in range(N):\n",
    "    # Random point missing\n",
    "    point_mask = np.random.rand(T) < p_missing_point\n",
    "    X_missing[i, point_mask, :] = np.nan\n",
    "\n",
    "    # Random block missing\n",
    "    block_mask = np.random.rand(T) < p_missing_block\n",
    "    for j in range(T):\n",
    "        if block_mask[j]:\n",
    "            end_idx = min(j + duration_block, T)\n",
    "            X_missing[i, j:end_idx, :] = np.nan\n",
    "\n",
    "print(f\"Number of missing values: {np.sum(np.isnan(X_missing))} \"\n",
    "    f\"({np.sum(np.isnan(X_missing)) / X_missing.size:.2f}%).\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the utility function we defined before to visualize the pattern of missing data that we generated in a given data sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 800x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_missing_data(X_missing[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we will fill the missing values using forward-filling imputation. This is usually a good baseline for time series because, especially if the window of missing values is short, we can assume that the missing values will not be so different from the last observed value.\n",
    "\n",
    "Note that we do this imputation variable-wise, i.e., we carry forward the last observed value at each variable. \n",
    "\n",
    "This imputation is also necessary to get meaningful states out of the Reservoir and to prevent having `nan` in the sequence of Reservoir states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Missing values after forward fill: 0\n"
     ]
    }
   ],
   "source": [
    "# Forward-fill the missing entries \n",
    "X_missing_filled = forward_fill_timewise(X_missing.copy())\n",
    "print(\"Missing values after forward fill:\", np.isnan(X_missing_filled).sum()) # should be 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute Reservoir states\n",
    "\n",
    "We are now ready to build our Reservoir and feed it with the time series to get the states.\n",
    "\n",
    "Importantly, we will use the Reservoir states to perform forecasting, so do **not** want to use a bidirectional readout that also processes the time series backward.\n",
    "Thus, we must set `bidir=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "states.shape: (640, 29, 700)\n",
      "Missing values in states: 0\n"
     ]
    }
   ],
   "source": [
    "# Initialize the Reservoir\n",
    "res = Reservoir(\n",
    "    n_internal_units=700,\n",
    "    spectral_radius=0.7,\n",
    "    leak=0.7,\n",
    "    connectivity=0.2,\n",
    "    input_scaling=0.05)\n",
    "\n",
    "# Compute the Reservoir states\n",
    "states = res.get_states(X_missing_filled, bidir=False)  # shape: [N, T, H]\n",
    "_, _, H = states.shape\n",
    "print(\"states.shape:\", states.shape)\n",
    "print(\"Missing values in states:\", np.isnan(states).sum()) # should be 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To perform imputation, we will train a readout to do a forecasting task, i.e., given the Readout state $h(t)$ the readout should predict $x(t+h)$ where $h$ is the forecast `horizon`.\n",
    "\n",
    "The reason why we do forecasting, rather than just mapping $h(t)$ into $x(t)$ is that in this way the readout just learns to copy in output the input values and does not know what to do when the input value (at inference time) will be missing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After shifting:\n",
      " states => (640, 28, 700),\n",
      " X_missing_future => (640, 28, 12)\n"
     ]
    }
   ],
   "source": [
    "horizon = 1  # forecast horizon\n",
    "states = states[:, :T - horizon, :] # drop the last 'horizon' steps\n",
    "X_missing_future = X_missing[:, horizon:, :] # drop the first 'horizon' steps\n",
    "\n",
    "_, T_new, _ = X_missing_future.shape\n",
    "print(f\"After shifting:\\n states => {states.shape},\\n X_missing_future => {X_missing_future.shape}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training the readout we have to flatten the states and the inputs into 2-dimensional arrays, because we have to map a single $H$-dimensional vector into a $V$-dimensional one. \n",
    "\n",
    "To flatten the data, we concatenate the sample and the time dimensions. This is OK, because the Reservoir should have embedded the historical information into its states, which can then be treated independently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the states and the input values\n",
    "states_2d = states.reshape(N * T_new, H)\n",
    "X_missing_future_2d = X_missing_future.reshape(N * T_new, V)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the readout\n",
    "\n",
    "Clearly, we can train the readout only on the instances $\\{h(t), x(t+h)\\}$ where the target $x(t+h)$ is not missing. So, we have to first filter out all the time steps where the target is missing.\n",
    "\n",
    "Then, we will define the readout and fit it to the training data that we defined. We will use a simple Ridge Regressor but, of course, more sophisticated regressors can be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train samples (rows) with no missing output: 8387\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: black;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-1 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-1 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-1 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-1 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 1ex;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-1 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>Ridge(alpha=1)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;Ridge<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.linear_model.Ridge.html\">?<span>Documentation for Ridge</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>Ridge(alpha=1)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "Ridge(alpha=1)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only rows with no missing target\n",
    "not_missing_mask = ~np.isnan(X_missing_future_2d).any(axis=1)\n",
    "X_train = states_2d[not_missing_mask]\n",
    "Y_train = X_missing_future_2d[not_missing_mask]\n",
    "print(\"Train samples (rows) with no missing output:\", X_train.shape[0])\n",
    "\n",
    "# Create and train the readout\n",
    "model = Ridge(alpha=1)\n",
    "model.fit(X_train, Y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Impute missing values and evaluate performances\n",
    "\n",
    "Once the readout is trained, we can use it to predict **all** the target values, including those with missing values that we excluded from the training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions_2d = model.predict(states_2d)  # shape: [N*T_new, V]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, we will replace all the missing values in `X_missing_future` with the predicted values, which represent the imputed ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill with the missing values\n",
    "X_imputed_future_2d = X_missing_future_2d.copy()\n",
    "missing_mask_2d = np.isnan(X_imputed_future_2d)\n",
    "X_imputed_future_2d[missing_mask_2d] = predictions_2d[missing_mask_2d]\n",
    "\n",
    "# Reshape back to the original shape\n",
    "X_imputed_future = X_imputed_future_2d.reshape(N, T_new, V)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sine our imputation is based on prediction, we do not have imputed values for the first `horizon` steps. So, we will take those steps from `X_missing_filled`, while the rest from `X_imputed_future`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_imputed = X_missing_filled.copy()  # start with forward-filled\n",
    "X_imputed[:, horizon:, :] = X_imputed_future"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Done! At this point, we can check the performance of our imputation method by computing the MSE between the time seris with imputed values and the original one where the values are not missing.\n",
    "\n",
    "We can also compare against the imputation obtained with the other simple baselines.\n",
    "\n",
    "Note that in this case zero- and mean-imputations give the same result because the data are normalized with a standard scaler."
   ]
  },
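  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch, the score computed below with `np.mean((X - X_imputed)**2)` can be written as follows. The notation is assumed here for illustration: $x_v^{(n)}(t)$ is the true value of feature $v$ in sample $n$ at time $t$, and $\\hat{x}_v^{(n)}(t)$ is the corresponding value after imputation.\n",
    "\n",
    "$$\\mathrm{MSE} = \\frac{1}{N T V} \\sum_{n=1}^{N} \\sum_{t=1}^{T} \\sum_{v=1}^{V} \\left( x_v^{(n)}(t) - \\hat{x}_v^{(n)}(t) \\right)^2$$\n",
    "\n",
    "The average runs over all entries, not only the missing ones. Observed entries are left untouched by every method and contribute zero error, so the reported value is lower than the error measured on the missing entries alone."
   ]
  },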
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE (Reservoir imp): 0.2500\n",
      "MSE (f-fill imp): 0.3843\n",
      "MSE (zero imp): 0.4452\n",
      "MSE (mean imp): 0.4454\n"
     ]
    }
   ],
   "source": [
    "# Compute zero-imputed values\n",
    "X_zero_imputed = X_missing.copy()\n",
    "X_zero_imputed[np.isnan(X_zero_imputed)] = 0.0\n",
    "\n",
    "# Compute mean-imputation with SimpleImputer\n",
    "imp = SimpleImputer(strategy='mean')\n",
    "X_mean_imputed = imp.fit_transform(X_missing.reshape(N * T, V)).reshape(N, T, V)\n",
    "\n",
    "# Print the MSE between the true and the imputed values\n",
    "mse_imputed = np.mean((X - X_imputed)**2)\n",
    "print(f\"MSE (Reservoir imp): {mse_imputed:.4f}\")\n",
    "mse_filled = np.mean((X - X_missing_filled)**2)\n",
    "print(f\"MSE (f-fill imp): {mse_filled:.4f}\")\n",
    "mse_zero_imputed = np.mean((X - X_zero_imputed)**2)\n",
    "print(f\"MSE (zero imp): {mse_zero_imputed:.4f}\")\n",
    "mse_mean_imputed = np.mean((X - X_mean_imputed)**2)\n",
    "print(f\"MSE (mean imp): {mse_mean_imputed:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also make a visualization of the true and imputed values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1000x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sample_idx, feature_idx = 2,9\n",
    "\n",
    "plt.figure(figsize=(10, 4))\n",
    "plt.plot(X[sample_idx, :, feature_idx], 'o-', label='Original')\n",
    "plt.plot(X_missing_filled[sample_idx, :, feature_idx], 'x--', label='ffill imputed')\n",
    "plt.plot(X_zero_imputed[sample_idx, :, feature_idx], 'x--', label='zero imputed')\n",
    "plt.plot(X_imputed[sample_idx, :, feature_idx], 's--', label='Reservoir imputed')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Partial missing features\n",
    "\n",
    "In this setting, we allow only a subset of the features to be missing at each time step. This can happen, for example, when we have $V$ independet sensors and only a few of them are not collecting the measuremements.\n",
    "\n",
    "### Generate missing values\n",
    "\n",
    "Here, we modify the procedure for generating missing values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 800x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of missing values: 112700 (0.51%).\n"
     ]
    }
   ],
   "source": [
    "p_missing_point = 0.2\n",
    "p_missing_block = 0.1\n",
    "duration_block = 5\n",
    "\n",
    "X_missing = X.copy()\n",
    "for i in range(N):\n",
    "    # Random point missing -- each feature gets its own mask\n",
    "    pmask = np.random.rand(T, V) < p_missing_point\n",
    "    X_missing[i, pmask] = np.nan\n",
    "\n",
    "    # Random block missing -- each feature gets its own blocks\n",
    "    for v in range(V):\n",
    "        block_mask = np.random.rand(T) < p_missing_block\n",
    "        for j in range(T):\n",
    "            if block_mask[j]:\n",
    "                end_idx = min(j + duration_block, T)\n",
    "                X_missing[i, j:end_idx, v] = np.nan\n",
    "\n",
    "plot_missing_data(X_missing[0])\n",
    "\n",
    "print(f\"Number of missing values: {np.sum(np.isnan(X_missing))} \"\n",
    "    f\"({np.sum(np.isnan(X_missing)) / X_missing.size:.2f}%).\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we repeat the same steps as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward-fill the missing entries before computing the reservoir states\n",
    "X_missing_filled = forward_fill_timewise(X_missing.copy())\n",
    "\n",
    "# Create the Reservoir\n",
    "res = Reservoir(\n",
    "    n_internal_units=700,\n",
    "    spectral_radius=0.7,\n",
    "    leak=0.7,\n",
    "    connectivity=0.2,\n",
    "    input_scaling=0.05)\n",
    "\n",
    "# Compute the Reservoir states\n",
    "states = res.get_states(X_missing_filled, bidir=False)\n",
    "_, _, H = states.shape\n",
    "\n",
    "# Sift the states and the input values by horizon\n",
    "horizon = 1\n",
    "states = states[:, :T - horizon, :]\n",
    "X_missing_future = X_missing[:, horizon:, :]\n",
    "_, T_new, _ = X_missing_future.shape\n",
    "\n",
    "# Flatten the states and the input values\n",
    "states_2d = states.reshape(N * T_new, H) \n",
    "X_missing_future_2d = X_missing_future.reshape(N * T_new, V) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before, we were mapping the Reservoir state $h(t) \\in \\mathbb{R}^H$ into the future input $x(t+h) \\in \\mathbb{R}^V$, which was either completely missing or completely available.\n",
    "\n",
    "This time, some variables can be missing and other be present, meaning that we need to predict a variable $i$ and a variable $j$ independently.\n",
    "In other words, we need to train $V$ independent readouts, each one mapping the state $h(t)$ into the $v$-th variable $x_v(t+h)$."
   ]
  },
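  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what each readout optimizes, assume a ridge regressor with an unpenalized intercept, as in the `Ridge` model used in this notebook. The symbols $\\mathbf{w}_v$, $b_v$, $\\alpha$, and $\\Omega_v$ are introduced here only for illustration. As above, $h(t)$ denotes the Reservoir state, while the $h$ in $t+h$ is the prediction horizon (`horizon` in the code). Let $\\Omega_v$ be the set of (sample, time) pairs $(n, t)$ for which the target $x_v^{(n)}(t+h)$ is observed. The $v$-th readout then solves\n",
    "\n",
    "$$\\min_{\\mathbf{w}_v,\\, b_v} \\sum_{(n,t) \\in \\Omega_v} \\left( x_v^{(n)}(t+h) - \\mathbf{w}_v^\\top h^{(n)}(t) - b_v \\right)^2 + \\alpha \\lVert \\mathbf{w}_v \\rVert_2^2,$$\n",
    "\n",
    "where $\\alpha$ is the regularization strength. Since $\\Omega_v$ differs from feature to feature, each readout is fitted on a different number of training rows."
   ]
  },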
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Feature 0: #train samples = 8815\n",
      "Feature 1: #train samples = 8741\n",
      "Feature 2: #train samples = 8488\n",
      "Feature 3: #train samples = 8854\n",
      "Feature 4: #train samples = 8823\n",
      "Feature 5: #train samples = 8856\n",
      "Feature 6: #train samples = 8643\n",
      "Feature 7: #train samples = 8908\n",
      "Feature 8: #train samples = 8752\n",
      "Feature 9: #train samples = 8518\n",
      "Feature 10: #train samples = 8610\n",
      "Feature 11: #train samples = 8510\n"
     ]
    }
   ],
   "source": [
    "predictions_2d = np.zeros_like(X_missing_future_2d) \n",
    "\n",
    "for v in range(V):\n",
    "    # For feature v, identify rows where the target is not missing\n",
    "    not_missing_mask_v = ~np.isnan(X_missing_future_2d[:, v])\n",
    "    \n",
    "    # Training data: states + single feature's target\n",
    "    X_train_v = states_2d[not_missing_mask_v]\n",
    "    y_train_v = X_missing_future_2d[not_missing_mask_v, v]\n",
    "    print(f\"Feature {v}: #train samples = {len(y_train_v)}\")\n",
    "    \n",
    "    # Fit a ridge model (or any regressor) for this feature\n",
    "    model_v = Ridge(alpha=1.0)\n",
    "    model_v.fit(X_train_v, y_train_v)\n",
    "    \n",
    "    # Predict for ALL time steps (including missing) in states_2d\n",
    "    predictions_2d[:, v] = model_v.predict(states_2d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next steps are similar to those taken before: we take the predictions as imputed values and put them in the correct place where the data are missing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_imputed_future_2d = X_missing_future_2d.copy()\n",
    "missing_mask_2d = np.isnan(X_imputed_future_2d)  \n",
    "\n",
    "# Fill with feature-specific predictions\n",
    "X_imputed_future_2d[missing_mask_2d] = predictions_2d[missing_mask_2d]\n",
    "\n",
    "# Reshape back to [N, T_new, V]\n",
    "X_imputed_future = X_imputed_future_2d.reshape(N, T_new, V)\n",
    "\n",
    "#  Use the forward-filled data for the first 'horizon' steps\n",
    "X_imputed = X_missing_filled.copy()  \n",
    "X_imputed[:, horizon:, :] = X_imputed_future"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To conclude, we compare the results obtained with the other imputation methods. Also in this case, we see that we obtain better performance with the Reservoir-based imputation compared to other imputation techniques."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE (Reservoir imp): 0.3154\n",
      "MSE (f-fill imp): 0.4565\n",
      "MSE (zero imp): 0.5047\n"
     ]
    }
   ],
   "source": [
    "# Compute zero-imputed values\n",
    "X_zero_imputed = X_missing.copy()\n",
    "X_zero_imputed[np.isnan(X_zero_imputed)] = 0.0\n",
    "\n",
    "# Print the MSE between the true and the imputed values\n",
    "mse_imputed = np.mean((X - X_imputed)**2)\n",
    "print(f\"MSE (Reservoir imp): {mse_imputed:.4f}\")\n",
    "mse_filled = np.mean((X - X_missing_filled)**2)\n",
    "print(f\"MSE (f-fill imp): {mse_filled:.4f}\")\n",
    "mse_zero_imputed = np.mean((X - X_zero_imputed)**2)\n",
    "print(f\"MSE (zero imp): {mse_zero_imputed:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tgp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}