{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Machine Learning for Time Series (Master MVA)**\n",
    "\n",
    "- [Link to the class material.](http://www.laurentoudre.fr/ast.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "In this notebook, we illustrate the following concept:\n",
    "- graph signal processing."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Import**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from math import asin, cos, radians, sin, sqrt\n",
    "\n",
    "import contextily as cx\n",
    "import geopandas\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from loadmydata.load_molene_meteo import load_molene_meteo_dataset\n",
    "from matplotlib.dates import DateFormatter\n",
    "from pygsp import graphs\n",
    "from scipy.linalg import eigh\n",
    "from scipy.spatial.distance import pdist, squareform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Coordinate reference system assigned to the station GeoDataFrame (through set_crs).\n",
    "CRS = \"EPSG:4326\"\n",
    "\n",
    "# Names of the stations kept for the study: both the station table (column Nom)\n",
    "# and the measurement table (column station_name) are filtered with this list.\n",
    "STATION_LIST = [\n",
    "    \"ARZAL\",\n",
    "    \"AURAY\",\n",
    "    \"BELLE ILE-LE TALUT\",\n",
    "    \"BIGNAN\",\n",
    "    \"BREST-GUIPAVAS\",\n",
    "    \"BRIGNOGAN\",\n",
    "    \"DINARD\",\n",
    "    \"GUERANDE\",\n",
    "    \"ILE DE GROIX\",\n",
    "    \"ILE-DE-BREHAT\",\n",
    "    \"KERPERT\",\n",
    "    \"LANDIVISIAU\",\n",
    "    \"LANNAERO\",\n",
    "    \"LANVEOC\",\n",
    "    \"LORIENT-LANN BIHOUE\",\n",
    "    \"LOUARGAT\",\n",
    "    \"MERDRIGNAC\",\n",
    "    \"NOIRMOUTIER EN\",\n",
    "    \"OUESSANT-STIFF\",\n",
    "    \"PLEUCADEUC\",\n",
    "    \"PLEYBER-CHRIST SA\",\n",
    "    \"PLOERMEL\",\n",
    "    \"PLOUDALMEZEAU\",\n",
    "    \"PLOUGUENAST\",\n",
    "    \"PLOUMANAC'H\",\n",
    "    \"POMMERIT-JAUDY\",\n",
    "    \"PONTIVY\",\n",
    "    \"PTE DE CHEMOULIN\",\n",
    "    \"PTE DE PENMARCH\",\n",
    "    \"PTE DU RAZ\",\n",
    "    \"QUIMPER\",\n",
    "    \"QUINTENIC\",\n",
    "    \"ROSTRENEN\",\n",
    "    \"SAINT-CAST-LE-G\",\n",
    "    \"SARZEAU SA\",\n",
    "    \"SIBIRIL S A\",\n",
    "    \"SIZUN\",\n",
    "    \"SPEZET\",\n",
    "    \"ST BRIEUC\",\n",
    "    \"ST NAZAIRE-MONTOIR\",\n",
    "    \"ST-SEGAL S A\",\n",
    "    \"THEIX\",\n",
    "    \"VANNES-SENE\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Utility functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fig_ax(figsize=(15, 3)):\n",
    "    \"\"\"Create a figure of size figsize and return what plt.subplots returns for it.\"\"\"\n",
    "    created = plt.subplots(figsize=figsize)\n",
    "    return created"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_line_graph(n_nodes=10) -> graphs.Graph:\n",
    "    \"\"\"Return a cyclic line graph (a ring) with unit edge weights.\n",
    "\n",
    "    Node i is linked to node (i + 1) % n_nodes, so the line is closed on\n",
    "    itself; the coordinates are set with the \"ring2D\" layout and the\n",
    "    combinatorial Laplacian is computed before the graph is returned.\n",
    "    \"\"\"\n",
    "    # cyclic shift of the identity: entry (i, (i + 1) % n_nodes) is equal to 1\n",
    "    shift = np.roll(np.eye(n_nodes), 1, axis=1)\n",
    "    # symmetrise with an element-wise maximum instead of a sum: identical for\n",
    "    # three nodes or more, but a two-node ring keeps a unit weight (not 2)\n",
    "    adjacency_matrix = np.maximum(shift, shift.T)\n",
    "    # a single node would otherwise be linked to itself\n",
    "    np.fill_diagonal(adjacency_matrix, 0)\n",
    "    line_graph = graphs.Graph(adjacency_matrix)\n",
    "    line_graph.set_coordinates(kind=\"ring2D\")\n",
    "    line_graph.compute_laplacian(\"combinatorial\")\n",
    "    return line_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_grid_graph(n_nodes_height=10, n_nodes_width=10) -> graphs.Graph:\n",
    "    \"\"\"Return a 2D grid graph with integer node coordinates.\n",
    "\n",
    "    The graph is graphs.Grid2d(n_nodes_height, n_nodes_width); each node gets\n",
    "    the coordinates (column, row) and the combinatorial Laplacian is computed\n",
    "    before the graph is returned.\n",
    "    \"\"\"\n",
    "    g = graphs.Grid2d(n_nodes_height, n_nodes_width)\n",
    "    # Grid2d numbers the nodes row by row (node = row * n_nodes_width + column),\n",
    "    # so the column index has to be the fastest-varying coordinate: the width\n",
    "    # goes first in meshgrid. With the arguments swapped, the coordinates only\n",
    "    # match the edges when the grid is square.\n",
    "    xx, yy = np.meshgrid(np.arange(n_nodes_width), np.arange(n_nodes_height))\n",
    "    coords = np.column_stack((xx.ravel(), yy.ravel()))\n",
    "    g.set_coordinates(coords)\n",
    "    g.compute_laplacian(\"combinatorial\")\n",
    "    return g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dms2dd(s):\n",
    "    \"\"\"Convert longitude and latitude strings to float.\"\"\"\n",
    "    # https://stackoverflow.com/a/50193328\n",
    "    # example: s =  \"\"\"48°51'18\"\"\"\n",
    "    degrees, minutes, seconds = re.split(\"[°'\\\"]+\", s[:-1])\n",
    "    direction = s[-1]\n",
    "    dd = float(degrees) + float(minutes) / 60 + float(seconds) / (60 * 60)\n",
    "    if direction in (\"S\", \"W\"):\n",
    "        dd *= -1\n",
    "    return dd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_geodesic_distance(point_1, point_2) -> float:\n",
    "    \"\"\"\n",
    "    Calculate the great circle distance (in km) between two points\n",
    "    on the earth (specified in decimal degrees)\n",
    "\n",
    "    https://stackoverflow.com/a/4913653\n",
    "    \"\"\"\n",
    "\n",
    "    lon1, lat1 = point_1\n",
    "    lon2, lat2 = point_2\n",
    "\n",
    "    # convert decimal degrees to radians\n",
    "    lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])\n",
    "\n",
    "    # haversine formula\n",
    "    dlon = lon2 - lon1\n",
    "    dlat = lat2 - lat1\n",
    "    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2\n",
    "    c = 2 * asin(sqrt(a))\n",
    "    r = 6371  # Radius of earth in kilometers. Use 3956 for miles\n",
    "    return c * r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_exponential_similarity(condensed_distance_matrix, bandwidth, threshold):\n",
    "    exp_similarity = np.exp(-(condensed_distance_matrix**2) / bandwidth / bandwidth)\n",
    "    res_arr = np.where(exp_similarity > threshold, exp_similarity, 0.0)\n",
    "    return res_arr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Graph signal processing\n",
    "\n",
    "A graph $G$ is a set of $N$ **nodes** connected with **edges**. A **graph signal** is a $\\mathbb{R}^N$ vector that is supported by the nodes of the graph $G$.\n",
    "\n",
    "Graph Signal Processing (GSP) is the set of methods used to study such objects."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Toy data\n",
    "\n",
    "Let us illustrate the basic principles of GSP on two toy graphs: the line graph and the 2D grid graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "line_graph = get_line_graph(50)  # 50 nodes\n",
    "line_graph.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_graph = get_grid_graph(20, 20)  # 20 by 20 grid\n",
    "grid_graph.plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now generate noisy signals on those two graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, (ax_0, ax_1) = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))\n",
    "# fix the seed so that the noise is the same from one run to the next\n",
    "np.random.seed(0)\n",
    "# generate a noisy sinusoid\n",
    "tt = np.linspace(0, 6 * np.pi, line_graph.N)\n",
    "signal_line = 5 * np.sin(tt) + np.random.normal(size=line_graph.N)\n",
    "# plot: on the graph (left) and node by node (right)\n",
    "line_graph.plot_signal(signal_line, ax=ax_0)\n",
    "ax_1.plot(signal_line)\n",
    "ax_1.set_xlabel(\"Node\")\n",
    "# assigning the last value keeps its text representation out of the output\n",
    "_ = ax_1.set_ylabel(\"Signal value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, (ax_0, ax_1) = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))\n",
    "\n",
    "# generate a noisy sinusoid\n",
    "x = np.linspace(-2 * np.pi, 2 * np.pi, 20)\n",
    "y = np.linspace(-2 * np.pi, 2 * np.pi, 20)\n",
    "xx, yy = np.meshgrid(x, y, sparse=True)\n",
    "z = 5 * np.sin(xx + yy)\n",
    "z += np.random.normal(size=z.shape)\n",
    "signal_grid = z.flatten()\n",
    "# plot\n",
    "ax_0.contourf(x, y, z)\n",
    "grid_graph.plot_signal(signal_grid, ax=ax_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fourier basis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Recall the definition of the Laplacian matrix.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Compute the Laplacian matrix for both the line graph and the grid graph.</p>\n",
    "    <p>Verify your result with the Laplacian matrix provided by the <tt>Graph</tt> class (available in <tt>g.L.todense()</tt> for a graph <tt>g</tt>).</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the GSP setting, the Fourier transform derives from the Laplacian $L$ eigendecomposition:\n",
    "\n",
    "$$\n",
    "L = U D U^T\n",
    "$$\n",
    "\n",
    "where $U$ contains (orthonormal) eigenvectors $u_i$ and $D$ is a diagonal matrix containing the eigenvalues.\n",
    "\n",
    "For a graph signal $f$, the associated Fourier transform $\\hat{f}$ is given by:\n",
    "\n",
    "$$\n",
    "\\hat{f}:=U^T f.\n",
    "$$\n",
    "\n",
    "To illustrate this definition, we can compute the Fourier basis on the two graph examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the line graph, we compute and display the eigenvalues and the first eigenvectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Laplacian eigendecomposition\n",
    "eigenvals_line, eigenvects_line = eigh(line_graph.L.todense())\n",
    "\n",
    "fig, (ax_0, ax_1) = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))\n",
    "\n",
    "# left panel: the eigenvalues, numbered from 1\n",
    "eigenvalue_numbers = range(1, eigenvals_line.size + 1)\n",
    "ax_0.plot(eigenvalue_numbers, eigenvals_line, \"-*\")\n",
    "ax_0.set_xlabel(\"Eigenvalue number\")\n",
    "ax_0.set_ylabel(\"Eigenvalue\")\n",
    "ax_0.set_title(\"Line graph\")\n",
    "\n",
    "# right panel: every other eigenvector among the first ten (columns 0, 2, ..., 8)\n",
    "for column in range(0, 10, 2):\n",
    "    ax_1.plot(eigenvects_line[:, column], label=f\"Eigenvector {column+1}\")\n",
    "_ = ax_1.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>What do you observe on the shape of the eigenvectors?</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>For the grid graph, compute and display the first and last eigenvectors.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Visually, which one of the two eigenvectors is the smoothest?</p>\n",
    "    <p>Recall the definition of a graph signal smoothness.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Compute and plot the smoothness of each eigenvector of the Laplacian (of the grid graph).</p>\n",
    "    <p>What do you observe? Is this expected?</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fourier transform\n",
    "\n",
    "Using the Fourier basis, we can now compute the Fourier transform of each signal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, (ax_0, ax_1) = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))\n",
    "\n",
    "# Fourier transform\n",
    "signal_line_fourier = eigenvects_line.T.dot(signal_line)\n",
    "signal_grid_fourier = eigenvects_grid.T.dot(signal_grid)\n",
    "\n",
    "\n",
    "# plot\n",
    "ax_0.plot(abs(signal_line_fourier), \"*-\")\n",
    "ax_0.set_title(\"Fourier transform (signal on the line graph)\")\n",
    "\n",
    "ax_1.plot(abs(signal_grid_fourier), \"*-\")\n",
    "_ = ax_1.set_title(\"Fourier transform (signal on the grid graph)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Given the Fourier representation of a signal, how can we recover the original signal?</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since there is a frequency representation, we can filter the signals, as in the classical signal processing setting. For instance, let us set to 0 all Fourier coefficients above a certain cut-off frequency."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we keep only 20% of the Fourier coefficients\n",
    "\n",
    "# filtering the line graph signal\n",
    "n_kept_line = signal_line_fourier.size // 5\n",
    "fourier_transform_filtered = np.zeros(signal_line_fourier.size)\n",
    "fourier_transform_filtered[:n_kept_line] = signal_line_fourier[:n_kept_line]\n",
    "signal_line_filtered = eigenvects_line.dot(fourier_transform_filtered)\n",
    "# plot: a single axes, since only the line graph signal is drawn in this cell\n",
    "fig, ax_0 = plt.subplots(figsize=(10, 5))\n",
    "ax_0.plot(abs(signal_line_fourier), \"*-\", label=\"Original Fourier transform\")\n",
    "ax_0.plot(\n",
    "    abs(fourier_transform_filtered),\n",
    "    \"*-\",\n",
    "    label=\"Filtered Fourier transform\",\n",
    "    alpha=0.5,\n",
    ")\n",
    "ax_0.set_title(\"Fourier transform (signal on the line graph)\")\n",
    "# without a legend, the labels given to the two curves are never shown\n",
    "_ = ax_0.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filtering the grid graph signal\n",
    "n_kept_grid = signal_grid_fourier.size // 5\n",
    "fourier_transform_filtered = np.zeros(signal_grid_fourier.size)\n",
    "fourier_transform_filtered[:n_kept_grid] = signal_grid_fourier[:n_kept_grid]\n",
    "signal_grid_filtered = eigenvects_grid.dot(fourier_transform_filtered)\n",
    "# plot on a new figure: axes created in an earlier cell belong to a figure\n",
    "# that has already been rendered, so drawing on them would show nothing here\n",
    "fig, ax_1 = plt.subplots(figsize=(10, 5))\n",
    "ax_1.plot(abs(signal_grid_fourier), \"*-\", label=\"Original Fourier transform\")\n",
    "ax_1.plot(\n",
    "    abs(fourier_transform_filtered),\n",
    "    \"*-\",\n",
    "    label=\"Filtered Fourier transform\",\n",
    "    alpha=0.5,\n",
    ")\n",
    "ax_1.set_title(\"Fourier transform (signal on the grid graph)\")\n",
    "\n",
    "# attach the legend to the axes that holds the curves\n",
    "_ = ax_1.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then reconstruct the signals from their filtered Fourier transforms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, (ax_0, ax_1) = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))\n",
    "\n",
    "line_graph.plot_signal(signal_line_filtered, ax=ax_0)\n",
    "ax_0.set_title(\"Reconstruction\")\n",
    "grid_graph.plot_signal(signal_grid_filtered, ax=ax_1)\n",
    "_ = ax_1.set_title(\"Reconstruction\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = fig_ax()\n",
    "ax.plot(signal_line, label=\"Original signal\")\n",
    "ax.plot(signal_line_filtered, label=\"Filtered signal\")\n",
    "_ = plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use this procedure to compress signals that are supported on arbitrary graphs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments\n",
    "\n",
    "Let us illustrate a few GSP methods on a real-world data set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_df, stations_df, description = load_molene_meteo_dataset()\n",
    "print(description)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# only keep a subset of stations, sorted by name: a table pivoted on the\n",
    "# station name has its columns sorted by name, so the station order (hence\n",
    "# the node order of graphs built from stations_df) lines up with it\n",
    "keep_cond = stations_df.Nom.isin(STATION_LIST)\n",
    "stations_df = stations_df[keep_cond].sort_values(\"Nom\").copy()\n",
    "keep_cond = data_df.station_name.isin(STATION_LIST)\n",
    "data_df = data_df[keep_cond].reset_index(drop=True)\n",
    "\n",
    "# convert temperature from Kelvin to Celsius\n",
    "data_df[\"temp\"] = data_df.t - 273.15  # temperature in Celsius\n",
    "\n",
    "# convert pandas df to geopandas df (one point per station)\n",
    "stations_gdf = geopandas.GeoDataFrame(\n",
    "    stations_df,\n",
    "    geometry=geopandas.points_from_xy(stations_df.Longitude, stations_df.Latitude),\n",
    ").set_crs(CRS)\n",
    "\n",
    "stations_gdf.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pivot the table. We now have a multivariate time series."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "temperature_df = data_df.pivot(index=\"date\", values=\"temp\", columns=\"station_name\")\n",
    "temperature_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot the position of the ground stations on a map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = stations_gdf.geometry.plot(figsize=(10, 10))\n",
    "cx.add_basemap(ax, crs=stations_gdf.crs.to_string(), zoom=8)\n",
    "ax.set_axis_off()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data exploration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can start by checking for some malfunctions in the stations. To that end, we simply count the number of NaNs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "temperature_df.isna().sum(axis=0).sort_values(ascending=False).head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After this, we can look at the (geodesic) distance between stations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stations_np = stations_df[[\"Longitude\", \"Latitude\"]].to_numpy()\n",
    "dist_mat_condensed = pdist(stations_np, metric=get_geodesic_distance)\n",
    "dist_mat_square = squareform(dist_mat_condensed)\n",
    "\n",
    "fig, ax = fig_ax((10, 8))\n",
    "_ = sns.heatmap(\n",
    "    dist_mat_square,\n",
    "    xticklabels=stations_df.Nom,\n",
    "    yticklabels=stations_df.Nom,\n",
    "    ax=ax,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>What are the two closest stations?</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>What are the two most distant stations?</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can plot the temperature evolution for the two closest stations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Plot the temperature evolution for the two closest stations.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Do the same for the two most distant stations.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph construction\n",
    "\n",
    "This network of sensors can be modeled as a graph, and the temperature signal as a series of graph signals.\n",
    "\n",
    "To that end, we need to define a graph."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Give two ways to derive an adjacency matrix from a distance matrix.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>The number of connected components of a graph is the multiplicity of the eigenvalue 0 of the Laplacian matrix. Can you explain why? Write a function <tt>is_connected</tt> which returns True if the input graph is connected and False otherwise.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_connected(graph) -> bool:\n",
    "    \"\"\"Return True if the input graph is connected and False otherwise.\n",
    "\n",
    "    Placeholder to be completed as part of the exercise: as long as the body\n",
    "    is the bare ellipsis below, the function returns None.\n",
    "    \"\"\"\n",
    "    ..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distance-based unweighted graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "threshold = 1  # km\n",
    "adjacency_matrix = squareform((dist_mat_condensed < threshold).astype(int))\n",
    "G = graphs.Graph(adjacency_matrix)\n",
    "print(\n",
    "    f\"The graph is {'not ' if not is_connected(G) else ''}connected, with {G.N} nodes, {G.Ne} edges\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = stations_gdf.geometry.plot(figsize=(10, 10))\n",
    "cx.add_basemap(ax, crs=stations_gdf.crs.to_string(), zoom=8)\n",
    "ax.set_axis_off()\n",
    "G.set_coordinates(stations_np)\n",
    "G.plot(ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Plot the number of edges with respect to the threshold.</p>\n",
    "    <p>What is approximately the lowest possible threshold for which the graph is connected?</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p> Plot the average degree with respect to the threshold.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distance-based weighted graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each vertex can be connected to other vertices by edges weighted by a Gaussian kernel:\n",
    "$$\n",
    "W_{ij} = \\exp\\left(-\\frac{\\|c_i-c_j\\|^2}{\\sigma^2}\\right) \\quad\\text{if}\\quad \\exp\\left(-\\frac{\\|c_i-c_j\\|^2}{\\sigma^2}\\right)>\\lambda,\\ 0\\ \\text{otherwise}\n",
    "$$\n",
    "where the $c_i$ are the station positions, $\\sigma$ is the so-called bandwidth parameter, and $\\lambda>0$ is a threshold. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = np.median(dist_mat_condensed)  # median heuristic\n",
    "threshold = 0.85\n",
    "adjacency_matrix_gaussian = squareform(\n",
    "    get_exponential_similarity(dist_mat_condensed, sigma, threshold)\n",
    ")\n",
    "G_gaussian = graphs.Graph(adjacency_matrix_gaussian)\n",
    "print(\n",
    "    f\"The graph is {'not ' if not is_connected(G_gaussian) else ''}connected, with {G_gaussian.N} nodes, {G_gaussian.Ne} edges\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = stations_gdf.geometry.plot(figsize=(10, 10))\n",
    "cx.add_basemap(ax, crs=stations_gdf.crs.to_string(), zoom=8)\n",
    "ax.set_axis_off()\n",
    "G_gaussian.set_coordinates(stations_np)\n",
    "G_gaussian.plot(ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>What is the influence of the threshold?</p>\n",
    "    <p>Choose an appropriate value for the threshold.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Correlation graph\n",
    "\n",
    "The correlation between the signals can also define a graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = sns.heatmap(temperature_df.corr())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p> Describe how to create a graph using the signal correlation. Code it.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ax = stations_gdf.geometry.plot(figsize=(10, 10))\n",
    "cx.add_basemap(ax, crs=stations_gdf.crs.to_string(), zoom=8)\n",
    "ax.set_axis_off()\n",
    "G_corr.set_coordinates(stations_np)\n",
    "G_corr.plot(ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that stations that are very far apart can be connected. This unveils a different structure."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph signal processing\n",
    "\n",
    "In this section, we set the graph to the one that comes from the Gaussian kernel."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Signal smoothness\n",
    "\n",
    "Let us study the signal smoothness, at each hour of measure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop the NaNs\n",
    "temperature_df_no_nan = temperature_df.dropna(axis=0, how=\"any\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Compute the smoothness for a specific hour of measure.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose a specific hour\n",
    "choosen_hour = pd.to_datetime(\"2014-01-01 01:00:00\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Display the smoothness evolution with time.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This displays interesting patterns."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p> Show the state of the network, when the signal is the smoothest.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-success\" role=\"alert\">\n",
    "    <p><b>Question</b></p>\n",
    "    <p>Do the same, when the signal is the least smooth.</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tutorial-mva-2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  },
  "toc-showcode": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}