{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Wm0emdxJJADg" }, "source": [ "# A Gentle Introduction to tsl" ] },
{ "cell_type": "markdown", "metadata": { "id": "K5RXhBhhMI_a" }, "source": [ "In this tutorial notebook, we will see how to train our custom-made **Spatiotemporal Graph Neural Network (STGNN) for traffic forecasting** using [**tsl (Torch Spatiotemporal)**](https://torch-spatiotemporal.readthedocs.io/), a Python library for **neural spatiotemporal data processing**, with a focus on Graph Neural Networks.\n", "\n", "It is built on the most widely used libraries of the **Python scientific computing ecosystem**, with the goal of providing a straightforward path from data preprocessing to model prototyping.\n", "\n", "In particular, tsl offers a wide range of utilities to develop neural networks in [PyTorch](https://pytorch.org/) and [PyTorch Geometric (PyG)](https://www.pyg.org/) for processing **spatiotemporal graph signals**." ] },
{ "cell_type": "markdown", "metadata": { "id": "7YtTUCcaI6cK" }, "source": [ "## Quickstart" ] },
{ "cell_type": "markdown", "metadata": { "id": "IwgAw9bkMI_e" }, "source": [ "### Installation\n", "\n", "Let's start by installing tsl from source, together with its dependencies. Installing tsl from GitHub ensures that you get the latest version." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "YKEZTeotmGIP" }, "outputs": [], "source": [ "# Install required packages.\n", "import os\n", "import torch\n", "os.environ['TORCH'] = torch.__version__\n", "print(torch.__version__)\n", "\n", "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n", "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n", "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git\n", "!pip install -q git+https://github.com/TorchSpatiotemporal/tsl.git" ] },
{ "cell_type": "markdown", "metadata": { "id": "ZxoqCutTYVk4" }, "source": [ "Refer to the [tsl](https://torch-spatiotemporal.readthedocs.io/en/latest/usage/quickstart.html) and [PyG](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) installation guidelines for the setup in other environments.\n", "\n", "Let's check that everything works."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hLAxpbejYVk4" }, "outputs": [], "source": [ "import tsl\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "\n", "print(f\"tsl version : {tsl.__version__}\")\n", "print(f\"torch version: {torch.__version__}\")\n", "\n", "pd.options.display.float_format = '{:.2f}'.format\n", "np.set_printoptions(edgeitems=3, precision=3)\n", "torch.set_printoptions(edgeitems=2, precision=3)\n", "\n", "# Utility functions ################\n", "def print_matrix(matrix):\n", "    return pd.DataFrame(matrix)\n", "\n", "def print_model_size(model):\n", "    tot = sum([p.numel() for p in model.parameters() if p.requires_grad])\n", "    out = f\"Number of model ({model.__class__.__name__}) parameters:{tot:10d}\"\n", "    print(\"=\" * len(out))\n", "    print(out)" ] },
{ "cell_type": "markdown", "metadata": { "id": "D_bK40BaYVk5" }, "source": [ "### Usage\n", "\n", "tsl is more than a collection of layers. We can group the library modules into:\n", "\n", "* **Data loading modules**: manage how to store, load, and preprocess spatiotemporal data, providing a simple interface to make data ready for downstream neural models.\n", "\n", "* **Inference modules**: models and engines that take spatiotemporal data as input to make inferences for the task at hand, e.g., forecasting or imputation.\n", "\n", "We will look at both in more detail in the next sections." ] },
{ "cell_type": "markdown", "metadata": { "id": "cF7JZwlSIBIA" }, "source": [ "## Loading and Preprocessing Data" ] },
{ "cell_type": "markdown", "metadata": { "id": "_Wy4OORcYVk5" }, "source": [ "### Loading a tabular dataset\n", "\n", "`tsl` comes with several datasets used in the spatiotemporal processing literature. You can find them in the submodule [`tsl.datasets`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/datasets.html).\n", "\n", "As an example, we start with the [MetrLA](https://paperswithcode.com/sota/traffic-prediction-on-metr-la) dataset, a common benchmark for traffic forecasting. The dataset contains traffic readings collected from 207 loop detectors on highways in Los Angeles County, aggregated in 5-minute intervals over 4 months, from March to June 2012. Loading the dataset is as simple as:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "BsPkG4uMYVk5" }, "outputs": [], "source": [ "from tsl.datasets import MetrLA\n", "\n", "dataset = MetrLA(root='./data')\n", "\n", "print(dataset)" ] },
{ "cell_type": "markdown", "metadata": { "id": "67IyIupLYVk6" }, "source": [ "All the datasets in tsl are subclasses of the base class [`tsl.datasets.Dataset`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/datasets_prototypes.html#tsl.datasets.prototypes.Dataset), which exposes useful APIs for spatiotemporal datasets. We can see that the data are organized as a 3-dimensional array, with:\n", "\n", "* **34,272** time steps (one every 5 minutes for 4 months)\n", "* **207** nodes (the loop detectors)\n", "* **1** channel (detected speed)\n", "\n", "Nice! Besides storing the data of interest, the dataset comes with useful tools."
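] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check on these numbers, the cell below recomputes the period covered by the data from the number of steps and the 5-minute sampling period. It also builds an all-zero placeholder array with the same `[time steps, nodes, channels]` layout. This is only a sketch: the sizes are hard-coded from the description above, and the array contains no real readings (those come from the `dataset` object)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Sizes reported above for MetrLA: [time steps, nodes, channels]\n", "n_steps, n_nodes, n_channels = 34272, 207, 1\n", "minutes_per_step = 5\n", "\n", "# 34,272 steps x 5 minutes = 171,360 minutes = 119 days, i.e., about 4 months\n", "n_days = n_steps * minutes_per_step / (60 * 24)\n", "print(f'Covered period: {n_days:.0f} days')\n", "\n", "# Placeholder with the same layout (all zeros, NOT the real data)\n", "x = np.zeros((n_steps, n_nodes, n_channels), dtype=np.float32)\n", "print(x.shape)  # (34272, 207, 1)\n", "\n", "# Indexing follows the layout: the full series of the first node and channel\n", "series_node_0 = x[:, 0, 0]\n", "print(series_node_0.shape)  # (34272,)"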
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L74iikTiYVk6" }, "outputs": [], "source": [ "print(f\"Sampling period: {dataset.freq}\")\n", "print(f\"Has missing values: {dataset.has_mask}\")\n", "print(f\"Percentage of missing values: {(1 - dataset.mask.mean()) * 100:.2f}%\")\n", "print(f\"Has exogenous variables: {dataset.has_covariates}\")\n", "print(f\"Covariates: {', '.join(dataset.covariates.keys())}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "TC8LfuK4YVk6" }, "source": [ "Let's look at the output. The dataset has missing entries, and `dataset.mask` is a binary indicator associated with each time step, node, and channel (ones indicate valid values).\n", "\n", "The dataset also has a **covariate** attribute (i.e., exogenous variables) – the distance matrix – containing the pairwise distances between sensors.\n", "\n", "You can access covariates with `dataset.{covariate_name}`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "_YgmDUtWYVk7" }, "outputs": [], "source": [ "print_matrix(dataset.dist)" ] },
{ "cell_type": "markdown", "metadata": { "id": "u1uMQREKYVk7" }, "source": [ "This matrix stores the pairwise distances between sensors, with `inf` denoting two non-neighboring sensors.\n", "\n", "Let's now check what the speed readings look like." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "hxqyccSSYVk7" }, "outputs": [], "source": [ "dataset.dataframe()" ] },
{ "cell_type": "markdown", "metadata": { "id": "PN0pLP1IYVk8" }, "source": [ "#### Connecting sensors\n", "\n", "Besides the time series, to properly use graph-based models, we need to __connect__ nodes somehow.\n", "\n", "With the method [`dataset.get_similarity()`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/datasets_prototypes.html#tsl.datasets.prototypes.Dataset.get_similarity) we can retrieve node similarities computed with different methods. The available similarity methods for a dataset are listed in `dataset.similarity_options`, while the default one is stored in `dataset.similarity_score`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "NRZHPBhmYVk8" }, "outputs": [], "source": [ "print(f\"Default similarity: {dataset.similarity_score}\")\n", "print(f\"Available similarity options: {dataset.similarity_options}\")\n", "print(\"==========================================\")\n", "\n", "sim = dataset.get_similarity(\"distance\") # or dataset.compute_similarity()\n", "\n", "print(\"Similarity matrix W:\")\n", "print_matrix(sim)" ] },
{ "cell_type": "markdown", "metadata": { "id": "vCcgblQAYVk8" }, "source": [ "With this method, we compute the weight $w^{i,j}$ of the edge connecting the $i$-th and $j$-th nodes as\n", "$$\n", "w^{i,j} = \\left\\{\\begin{array}{cl}\n", " \\exp \\left(-\\frac{\\operatorname{dist}\\left(i, j\\right)^{2}}{\\gamma}\\right) & \\operatorname{dist}\\left(i, j\\right) \\leq \\delta \\\\\n", " 0 & \\text{otherwise}\n", "\\end{array}\\right. ,\n", "$$\n", "where $\\operatorname{dist}\\left(i, j\\right)$ is the distance between the $i$-th and $j$-th nodes, $\\gamma$ controls the kernel width, and $\\delta$ is a threshold. Notice that in this case the similarity matrix is not symmetric, since the preprocessed distance matrix is not symmetric either.\n", "\n", "So far so good: now we can build an adjacency matrix out of the computed similarity.\n", "\n", "The method [`dataset.get_connectivity()`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/datasets_prototypes.html#tsl.datasets.prototypes.Dataset.get_connectivity) – calling `dataset.get_similarity()` under the hood – provides useful preprocessing options and finally returns a possibly sparse, possibly weighted, adjacency matrix." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "goTBHPNGYVk8" }, "outputs": [], "source": [ "connectivity = dataset.get_connectivity(threshold=0.1,\n", " include_self=False,\n", " normalize_axis=1,\n", " layout=\"edge_index\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "TsPSdKmRYVk9" }, "source": [ "Let's see what happens with this function call:\n", "\n", "1. compute the similarity matrix as before;\n", "1. set to 0 the values **below** 0.1 (`threshold=0.1`);\n", "1. **remove** self loops (`include_self=False`);\n", "1. **normalize** edge weights by the **in-degree** of nodes (`normalize_axis=1`);\n", "1. request the sparse **COO layout** of PyG (`layout=\"edge_index\"`).\n", "\n", "The connectivity with `edge_index` layout is provided in COO format, adopting the convention and notation used in PyTorch Geometric. The returned connectivity is a tuple (`edge_index`, `edge_weight`), where `edge_index` lists all edges as pairs of source-target nodes (shape `[2, E]`) and `edge_weight` (shape `[E]`) stores the corresponding weights."
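] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the steps above concrete, the cell below re-implements them in plain PyTorch on a hypothetical 3-node similarity matrix. It is a sketch of the listed operations, not tsl's actual code.\n", "\n", "It rests on two assumptions about the layout, which you should check against the tsl documentation:\n", "\n", "* row `i` of the dense matrix collects the weights of the edges *entering* node `i`, so that normalizing along axis 1 divides by the in-degree, as stated above;\n", "* `edge_index` stores (source, target) pairs, as in PyG.\n", "\n", "Under these assumptions, reading the weights back with `adj[edge_index[1], edge_index[0]]`, as done later in this notebook, returns `edge_weight`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Hypothetical 3-node similarity matrix (step 1). By assumption, sim[i, j] is the\n", "# weight of the edge j -> i, so row i collects the edges entering node i.\n", "sim = torch.tensor([[1.0, 0.5, 0.05],\n", "                    [0.8, 1.0, 0.3],\n", "                    [0.0, 0.2, 1.0]])\n", "\n", "adj = sim.clone()\n", "adj[adj < 0.1] = 0.0     # step 2: zero out values below the threshold\n", "adj.fill_diagonal_(0.0)  # step 3: remove self loops\n", "# step 4: divide each row by its sum, i.e., by the weighted in-degree of node i\n", "adj = adj / adj.sum(dim=1, keepdim=True).clamp(min=1e-8)\n", "\n", "# step 5: COO layout, listing the non-zero entries as (source, target) pairs\n", "rows, cols = adj.nonzero(as_tuple=True)\n", "edge_index = torch.stack([cols, rows])  # [2, E]: sources in row 0, targets in row 1\n", "edge_weight = adj[rows, cols]           # [E]\n", "\n", "print(edge_index)   # edges 1->0, 0->1, 2->1, 1->2\n", "print(edge_weight)  # approx. 1.000, 0.727, 0.273, 1.000\n", "\n", "# Reading the weights back from the dense matrix, as done later in this notebook\n", "assert torch.equal(adj[edge_index[1], edge_index[0]], edge_weight)"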
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jD8S74swYVk9" }, "outputs": [], "source": [ "edge_index, edge_weight = connectivity\n", "\n", "print(f'edge_index {edge_index.shape}:\\n', edge_index)\n", "print(f'edge_weight {edge_weight.shape}:\\n', edge_weight)" ] },
{ "cell_type": "markdown", "metadata": { "id": "KaGwFbs4YVk9" }, "source": [ "The `\"dense\"` layout instead corresponds to the weighted adjacency matrix $A \\in \\mathbb{R}^{N \\times N}$. The module [`tsl.ops.connectivity`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/ops.html#module-tsl.ops.connectivity) contains useful operations on connectivities, including methods to change layout." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "gugmaNIpYVk9" }, "outputs": [], "source": [ "from tsl.ops.connectivity import edge_index_to_adj\n", "\n", "adj = edge_index_to_adj(edge_index, edge_weight)\n", "print(f'A {adj.shape}:')\n", "print_matrix(adj)" ] },
{ "cell_type": "markdown", "metadata": { "id": "aQJMPegnH8vz" }, "source": [ "Conversely, given `edge_index`, the edge weights in COO order can be easily read back from the dense matrix:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "P_n5cO8-H7h8" }, "outputs": [], "source": [ "print(f'Sparse edge weights:\\n', adj[edge_index[1], edge_index[0]])" ] },
{ "cell_type": "markdown", "metadata": { "id": "3ZB1SiBYYVk-" }, "source": [ "### Building a PyTorch-ready dataset\n", "\n", "In this section, we will see how to start from a dataset of this kind and fetch the **spatiotemporal graph signals** that are then given as input to a neural network (e.g., an STGNN).\n", "\n", "The first class that comes to our aid is [`tsl.data.SpatioTemporalDataset`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_pytorch_datasets.html#tsl.data.SpatioTemporalDataset). This class is a subclass of [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) and is in charge of mapping a tabular dataset represented in your preferred format (e.g., NumPy array, pandas DataFrame, or the aforementioned [`tsl.datasets.Dataset`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/datasets_prototypes.html#tsl.datasets.prototypes.Dataset)) to a PyTorch-ready implementation.\n", "\n", "In particular, a `SpatioTemporalDataset` object can be used to:\n", "* Perform the **data manipulation** operations required to feed the data to a PyTorch module (e.g., casting data to `torch.tensor`, handling possibly different `shapes`, synchronizing temporal data).\n", "* Create **`(input, target)` samples** for supervised learning following the [**sliding window**](https://torch-spatiotemporal.readthedocs.io/en/latest/usage/spatiotemporal_dataset.html#sliding-window) approach.\n", "* Define how data should be **arranged** in a **spatiotemporal graph signal** (e.g., which are the inputs and targets, how node attributes and covariates are mapped into a single graph).\n", "* **Preprocess** data before creating a **spatiotemporal graph signal** by applying [**transformations**](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/transforms.html) or [**scaling**](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_preprocessing.html) operations.\n", "\n", "Let's see how to go from a `Dataset` to a `SpatioTemporalDataset`."
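] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before doing so, it may help to see what the sliding-window approach does to the time axis. The cell below is a minimal sketch built around a hypothetical helper, `sliding_window_indices`, which is not part of tsl. It assumes that:\n", "\n", "* each sample uses `window` consecutive steps as input, followed immediately (no delay) by `horizon` target steps;\n", "* consecutive samples start `stride` steps apart.\n", "\n", "If `SpatioTemporalDataset` follows the same convention, the last printed value should match the number of samples it reports below." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def sliding_window_indices(n_steps, window, horizon, stride=1):\n", "    # Hypothetical helper (not part of tsl): time indices of each (input, target) pair.\n", "    # Input: `window` consecutive steps; target: the next `horizon` steps (no delay).\n", "    span = window + horizon\n", "    samples = []\n", "    for start in range(0, n_steps - span + 1, stride):\n", "        input_idx = np.arange(start, start + window)\n", "        target_idx = np.arange(start + window, start + span)\n", "        samples.append((input_idx, target_idx))\n", "    return samples\n", "\n", "# Toy series of 30 steps, with the same window and horizon used below for MetrLA\n", "samples = sliding_window_indices(n_steps=30, window=12, horizon=12, stride=1)\n", "print(len(samples))   # 7 = 30 - (12 + 12) + 1\n", "print(samples[0][0])  # input steps 0, ..., 11\n", "print(samples[0][1])  # target steps 12, ..., 23\n", "print(samples[1][0])  # next sample, shifted by stride=1: steps 1, ..., 12\n", "\n", "# Same count for the full MetrLA series (34,272 steps) under these assumptions\n", "print(len(sliding_window_indices(34272, window=12, horizon=12)))  # 34249"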
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s2VPMgQnYVk-" }, "outputs": [], "source": [ "from tsl.data import SpatioTemporalDataset\n", "\n", "torch_dataset = SpatioTemporalDataset(target=dataset.dataframe(),\n", " connectivity=connectivity,\n", " mask=dataset.mask,\n", " horizon=12,\n", " window=12,\n", " stride=1)\n", "print(torch_dataset)" ] },
{ "cell_type": "markdown", "metadata": { "id": "7Ap77CYPYVk-" }, "source": [ "As you can see, the number of samples is not the same as the number of time steps in the dataset. Indeed, we sliced the historical time series with a **sliding window**, using a **lookback window** of **12 time steps** (`window=12`) and a corresponding **horizon** of **12 time steps** (`horizon=12`). Thus, a single sample spans a total of 24 time steps. The `stride` parameter sets how many time steps separate two consecutive samples. The following picture helps visualize how these (and more) parameters affect the slicing of the original time series into samples.\n", "\n", "*(Figure: slicing of the original time series into samples with the sliding-window approach.)*" ] },
{ "cell_type": "markdown", "metadata": { "id": "vjHwpRMYBCbt" }, "source": [ "### Spatiotemporal graph signals in tsl\n", "\n", "We now have a PyTorch-based dataset containing a collection of spatiotemporal graph signals. We can fetch samples in the same way we fetch elements of a Python list. Let's look in detail at the layout of a sample:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "_G7-8HMcYVk_" }, "outputs": [], "source": [ "sample = torch_dataset[0]\n", "print(sample)" ] },
{ "cell_type": "markdown", "metadata": { "id": "RXSGWvK5YVk_" }, "source": [ "A sample is of type [`tsl.data.Data`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_objects.html#tsl.data.Data), the base class for representing spatiotemporal graph signals in tsl. This class extends [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data), preserving all its functionalities and\n", "adding utilities for spatiotemporal data processing. The main APIs of `Data` include:\n", "\n", "* **`Data.input`**: view on the tensors stored in `Data` that are meant to serve as input to the model.\n", " In the simplest case of a single node-attribute matrix, we could just have `Data.input.x`.\n", "* **`Data.target`**: view on the tensors stored in `Data` used as labels to train the model.\n", " In the common case of a single label, we could just have `Data.target.y`.\n", "* **`Data.edge_index`**: graph connectivity in COO format (i.e., as node pairs).\n", "* **`Data.edge_weight`**: weights of the graph connectivity, if any.\n", "* **`Data.mask`**: binary mask indicating the data in `Data.target.y` to be used as ground truth for the loss (default is `None`).\n", "* **`Data.transform`**: mapping of `ScalerModule` objects, whose keys must be\n", " transformable (or transformed) tensors in `Data`.\n", "* **`Data.pattern`**: mapping containing the *pattern* for each tensor in `Data`. Patterns add information about the dimensions of tensors (e.g., specifying which are the time step and node dimensions).\n", "\n", "None of these attributes are required, and custom attributes can be seamlessly added.\n", "\n", "Let's look in more detail at how each of these attributes is composed." ] },
{ "cell_type": "markdown", "metadata": { "id": "7tODoWS_YVk_" }, "source": [ "#### Input and Target\n", "\n", "`Data.input` and `Data.target` provide a **view** on the unique (shared) storage in `Data`, such that the same key in `Data.input` and `Data.target` cannot reference different objects."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TbB9HvSoYVk_" }, "outputs": [], "source": [ "sample.input.to_dict()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "HdorD4uLYVk_" }, "outputs": [], "source": [ "sample.target.to_dict()" ] },
{ "cell_type": "markdown", "metadata": { "id": "SNq8M-yaYVk_" }, "source": [ "#### Mask and Transform\n", "\n", "`mask` and `transform` are just symbolic links to the corresponding objects inside the storage. `Data` also exposes the properties `has_mask` and `has_transform`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "rAlbTjo7YVlA" }, "outputs": [], "source": [ "if sample.has_mask:\n", " print(sample.mask)\n", "else:\n", " print(\"Sample has no mask.\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "zW6tx42LYVlA" }, "outputs": [], "source": [ "if sample.has_transform:\n", " print(sample.transform)\n", "else:\n", " print(\"Sample has no transformation functions.\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "LR-LrjOwYVlA" }, "source": [ "#### Pattern\n", "\n", "The `pattern` mapping is useful to get a glimpse of how data are arranged.\n", "The convention we use is the following:\n", "\n", "* `'t'` stands for the **time steps** dimension\n", "* `'n'` stands for a **node** dimension\n", "* `'e'` stands for the **edge** dimension\n", "* `'f'` stands for a **feature** dimension\n", "* `'b'` stands for the **batch** dimension\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "4KeKgewEYVlA" }, "outputs": [], "source": [ "print(sample.pattern)\n", "print(\"================== Or we can print patterns and shapes together ==================\")\n", "print(sample)" ] },
{ "cell_type": "markdown", "metadata": { "id": "6X47QXG6Vsgx" }, "source": [ "### Batching spatiotemporal graph signals\n", "\n", "Getting a batch of spatiotemporal graph signals from a single dataset is as simple as accessing multiple elements from a list:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "WMYwon6dV9xt" }, "outputs": [], "source": [ "batch = torch_dataset[:5]\n", "print(batch)" ] },
{ "cell_type": "markdown", "metadata": { "id": "f4ba4122WKi1" }, "source": [ "As you can see, the time-varying elements (i.e., `x` and `y`) now have an additional dimension, denoted by pattern `b`: the batch dimension. Along this new first dimension, we stacked the features of the first 5 spatiotemporal graphs in the dataset.\n", "\n", "Note that this is possible only because we are assuming a fixed underlying topology, as also confirmed by the `edge_index` and `edge_weight` attributes. Explaining how `Data` objects with different graphs are batched together is out of the scope of this notebook." ] },
{ "cell_type": "markdown", "metadata": { "id": "9j-pA8lCYVlA" }, "source": [ "### Preparing the dataset for training\n", "\n", "Before running an experiment, two preprocessing steps are quite common:\n", "\n", "* **splitting** the dataset into **training/validation/test** sets;\n", "* **data preprocessing** (scaling/normalizing data, detrending).\n", "\n", "In tsl, these operations are managed by the [`tsl.data.SpatioTemporalDataModule`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_datamodule.html#tsl.data.datamodule.SpatioTemporalDataModule), which is based on the [`LightningDataModule`](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningDataModule.html#pytorch_lightning.core.LightningDataModule) from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). A `DataModule` keeps the training/validation/test splits, data preparation, and transformations standardized and consistent across different environments and experiments.\n", "\n", "Let's see an example:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "3aGnkoCgYVlA" }, "outputs": [], "source": [ "from tsl.data.datamodule import (SpatioTemporalDataModule,\n", " TemporalSplitter)\n", "from tsl.data.preprocessing import StandardScaler\n", "\n", "# Normalize data using mean and std computed over time and node dimensions\n", "scalers = {'target': StandardScaler(axis=(0, 1))}\n", "\n", "# Split data sequentially:\n", "# |------------ dataset -----------|\n", "# |--- train ---|- val -|-- test --|\n", "splitter = TemporalSplitter(val_len=0.1, test_len=0.2)\n", "\n", "dm = SpatioTemporalDataModule(\n", " dataset=torch_dataset,\n", " scalers=scalers,\n", " splitter=splitter,\n", " batch_size=64,\n", ")\n", "\n", "print(dm)" ] },
{ "cell_type": "markdown", "metadata": { "id": "229wC1LaYVlB" }, "source": [ "You may want to extend the base `SpatioTemporalDataModule` to add further processing that fits your needs.\n", "\n", "At this point, the `DataModule` object has not actually performed any processing yet (lazy approach).\n", "\n", "We can execute the preprocessing routines by calling the `dm.setup()` method." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "tMpX1zbfYVlB" }, "outputs": [], "source": [ "dm.setup()\n", "print(dm)" ] },
{ "cell_type": "markdown", "metadata": { "id": "wXQy8qMDYVlB" }, "source": [ "During `setup`, the datamodule performs the following operations:\n", "\n", "1. Splits the dataset into training/validation/test sets according to the provided [`Splitter`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_datamodule.html#tsl.data.datamodule.splitters.Splitter).\n", "1. Fits each scaler on the training data of the `torch_dataset` attribute matching the scaler's key (here, `target`)." ] },
{ "cell_type": "markdown", "metadata": { "id": "iW-Kns-7Rgqn" }, "source": [ "#### Splitters\n", "\n", "Splitters in tsl are the objects defining the policy of data splitting. [Read more](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_datamodule.html#splitters)" ] },
{ "cell_type": "markdown", "metadata": { "id": "lJHYk_TpRkdb" }, "source": [ "#### Scalers\n", "\n", "The `tsl.data.preprocessing` package offers several of the most common data normalization techniques under the `tsl.data.preprocessing.Scaler` interface.\n", "They adopt an API similar to `scikit-learn`'s scalers, with `fit`/`transform`/`fit_transform`/`inverse_transform` methods. [Read more](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_preprocessing.html#scalers)" ] },
{ "cell_type": "markdown", "metadata": { "id": "VKqiwe4lLTBo" }, "source": [ "## Building Spatiotemporal Graph Neural Networks" ] },
{ "cell_type": "markdown", "metadata": { "id": "d60HdqXCYVlB" }, "source": [ "In this section, we will see how to build a very simple Spatiotemporal Graph Neural Network.\n", "\n", "All the functions and classes needed to build neural networks in tsl are under the [`tsl.nn`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn.html) module.\n", "\n", "\n", "### The `nn` module\n", "\n", "The [`tsl.nn`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn.html) module is organized as follows:\n", "\n", "```\n", "tsl\n", "└── nn\n", "    ├── base\n", "    ├── blocks\n", "    ├── layers\n", "    └── models\n", "```\n", "\n", "The 3 most important submodules in it are [`layers`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn_layers.html), [`blocks`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn_blocks.html), and [`models`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn_models.html), ordered by increasing level of abstraction.\n", "\n", "#### Layers\n", "\n", "A **layer** is a basic building block for our neural networks. In simple words, a layer takes an input, performs one (or a few) operations, and returns a transformation of the input. Examples of layers are [`DiffConv`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn_layers.html#tsl.nn.layers.graph_convs.DiffConv), which implements the [diffusion convolution](https://arxiv.org/abs/1707.01926) operation, or [`LayerNorm`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn_layers.html#tsl.nn.layers.norm.LayerNorm).\n", "\n", "#### Blocks\n", "\n", "**Blocks** perform more complex transformations or combine several operations. We divide blocks into **encoders**, if they provide a representation of the input in a new space, and **decoders**, if they produce a meaningful output from a representation.\n", "\n", "#### Models\n", "\n", "We wrap a series of operations, represented by blocks and/or layers, in a **model**. A model takes a spatiotemporal graph signal as input and returns the desired output, e.g., the forecasted node features at future time steps." ] },
{ "cell_type": "markdown", "metadata": { "id": "aQiMvBjGYVlC" }, "source": [ "### Designing a custom STGNN\n", "\n", "Let's get our hands dirty and create our first simple STGNN! We will follow the **Time-then-Space** paradigm. We use a GRU, shared among all nodes, to process the temporal dimension. This yields a single feature vector for each node, which is then propagated through the underlying graph with a diffusion-convolutional GNN. Before and after these steps, we add linear transformations to encode the input features and decode the learned representations. We also make use of **node embeddings** (free parameters learned individually for each node) to make our STGNN a **global-local model** ([Cini et al., 2023](https://arxiv.org/abs/2302.04071)).\n", "\n", "All the layers that we need are provided inside `tsl.nn`. We use:\n", "* [`RNN`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn_blocks.html#tsl.nn.blocks.encoders.RNN) from `tsl.nn.blocks.encoders` for the GRU;\n", "* [`DiffConv`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/nn_layers.html#tsl.nn.layers.graph_convs.DiffConv) from `tsl.nn.layers.graph_convs` for the diffusion convolution;\n", "* `NodeEmbedding` from `tsl.nn.layers` for the node embeddings."
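] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before looking at the tsl implementation, the cell below walks through the same time-then-space pipeline in plain PyTorch, tracking tensor shapes at each step. It is a simplified sketch with hypothetical toy sizes, and it swaps in plain PyTorch components for the tsl ones:\n", "\n", "* `torch.nn.GRU` stands in for tsl's `RNN`;\n", "* a free `nn.Parameter` stands in for `NodeEmbedding`;\n", "* a single weighted aggregation over a random row-normalized adjacency matrix replaces `DiffConv`, which instead diffuses information over up to `k` hops." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# Hypothetical toy sizes: batch, window, nodes, input features, hidden units, horizon\n", "b, t, n, f, hidden, horizon = 4, 12, 5, 1, 16, 12\n", "\n", "x = torch.randn(b, t, n, f)                # input window: [b, t, n, f]\n", "adj = torch.rand(n, n)                     # random dense graph, only for shapes\n", "adj = adj / adj.sum(dim=1, keepdim=True)   # row-normalize\n", "\n", "encoder = nn.Linear(f, hidden)\n", "node_emb = nn.Parameter(torch.randn(n, hidden))  # one free vector per node\n", "gru = nn.GRU(hidden, hidden, batch_first=True)   # weights shared by all nodes\n", "decoder = nn.Linear(hidden, horizon * f)\n", "\n", "# Encode and add node embeddings (broadcast over batch and time)\n", "h = encoder(x) + node_emb                  # [b, t, n, hidden]\n", "\n", "# Time: each (batch, node) pair is an independent sequence for the shared GRU\n", "h = h.permute(0, 2, 1, 3).reshape(b * n, t, hidden)\n", "_, h_last = gru(h)                         # h_last: [1, b * n, hidden]\n", "h = h_last[-1].reshape(b, n, hidden)       # last state of each node\n", "\n", "# Space: one hop of weighted neighbor aggregation (simplified stand-in for DiffConv)\n", "z = torch.einsum('ij,bjh->bih', adj, h)    # [b, n, hidden]\n", "\n", "# Decode all horizon steps at once, then move time before nodes\n", "out = decoder(z).reshape(b, n, horizon, f).permute(0, 2, 1, 3)\n", "print(out.shape)                           # torch.Size([4, 12, 5, 1])"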
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ORL_KKbuYVlC" }, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "from tsl.nn.blocks.encoders import RNN\n", "from tsl.nn.layers import NodeEmbedding, DiffConv\n", "from einops.layers.torch import Rearrange  # reshape data with Einstein notation\n", "\n", "\n", "class TimeThenSpaceModel(nn.Module):\n", "    def __init__(self, input_size: int, n_nodes: int, horizon: int,\n", "                 hidden_size: int = 32,\n", "                 rnn_layers: int = 1,\n", "                 gnn_kernel: int = 2):\n", "        super(TimeThenSpaceModel, self).__init__()\n", "\n", "        self.encoder = nn.Linear(input_size, hidden_size)\n", "\n", "        self.node_embeddings = NodeEmbedding(n_nodes, hidden_size)\n", "\n", "        self.time_nn = RNN(input_size=hidden_size,\n", "                           hidden_size=hidden_size,\n", "                           n_layers=rnn_layers,\n", "                           cell='gru',\n", "                           return_only_last_state=True)\n", "\n", "        self.space_nn = DiffConv(in_channels=hidden_size,\n", "                                 out_channels=hidden_size,\n", "                                 k=gnn_kernel)\n", "\n", "        self.decoder = nn.Linear(hidden_size, input_size * horizon)\n", "        self.rearrange = Rearrange('b n (t f) -> b t n f', t=horizon)\n", "\n", "    def forward(self, x, edge_index, edge_weight):\n", "        # x: [batch time nodes features]\n", "        x_enc = self.encoder(x)  # linear encoder: x_enc = xΘ + b\n", "        x_emb = x_enc + self.node_embeddings()  # add node-identifier embeddings\n", "        h = self.time_nn(x_emb)  # temporal processing: x=[b t n f] -> h=[b n f]\n", "        z = self.space_nn(h, edge_index, edge_weight)  # spatial processing\n", "        x_out = self.decoder(z)  # linear decoder: z=[b n f] -> x_out=[b n t⋅f]\n", "        x_horizon = self.rearrange(x_out)\n", "        return x_horizon" ] },
{ "cell_type": "markdown", "metadata": { "id": "BM1Gp3m912ah" }, "source": [ "We can play with hyperparameters and make an instance of our model."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gp2oLPBorWeP" }, "outputs": [], "source": [ "hidden_size = 32 #@param\n", "rnn_layers = 1 #@param\n", "gnn_kernel = 2 #@param\n", "\n", "input_size = torch_dataset.n_channels # 1 channel\n", "n_nodes = torch_dataset.n_nodes # 207 nodes\n", "horizon = torch_dataset.horizon # 12 time steps\n", "\n", "stgnn = TimeThenSpaceModel(input_size=input_size,\n", " n_nodes=n_nodes,\n", " horizon=horizon,\n", " hidden_size=hidden_size,\n", " rnn_layers=rnn_layers,\n", " gnn_kernel=gnn_kernel)\n", "print(stgnn)\n", "print_model_size(stgnn)" ] },
{ "cell_type": "markdown", "metadata": { "id": "6fu_1-ojYVlC" }, "source": [ "Fine, we loaded the data and built a model, so let's train it!" ] },
{ "cell_type": "markdown", "metadata": { "id": "y3uABycjYVlC" }, "source": [ "### Setting up training\n", "\n", "We are now ready to train our model. Any training procedure can be used; in the following, we use PyTorch Lightning's `Trainer` to take care of the boilerplate. Recall that tsl is tightly integrated with widely used PyTorch-based libraries, such as PyTorch Lightning and PyTorch Geometric." ] },
{ "cell_type": "markdown", "metadata": { "id": "dxvEqv95YVlC" }, "source": [ "#### The Predictor\n", "\n", "In tsl, inference engines are implemented as a [`LightningModule`](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule). [`tsl.engines.Predictor`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/engines.html#tsl.engines.Predictor) is a base class that can be extended to build more complex forecasting approaches.\n", "These modules are meant to wrap deep models in order to ease the training and inference phases." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "GULyga_lYVlD" }, "outputs": [], "source": [ "from tsl.metrics.torch import MaskedMAE, MaskedMAPE\n", "from tsl.engines import Predictor\n", "\n", "loss_fn = MaskedMAE()\n", "\n", "metrics = {'mae': MaskedMAE(),\n", " 'mape': MaskedMAPE(),\n", " 'mae_at_15': MaskedMAE(at=2), # '2' indicates the third time step,\n", " # which corresponds to 15 minutes ahead\n", " 'mae_at_30': MaskedMAE(at=5),\n", " 'mae_at_60': MaskedMAE(at=11)}\n", "\n", "# setup predictor\n", "predictor = Predictor(\n", " model=stgnn, # our initialized model\n", " optim_class=torch.optim.Adam, # specify optimizer to be used...\n", " optim_kwargs={'lr': 0.001}, # ...and parameters for its initialization\n", " loss_fn=loss_fn, # which loss function to be used\n", " metrics=metrics # metrics to be logged during train/val/test\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "-34Ntg3lYVlD" }, "source": [ "Now let's finalize the last details. We use [TensorBoard](https://www.tensorflow.org/tensorboard/) to log and visualize metrics." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "a3zdVgNRYVlD" }, "outputs": [], "source": [ "from pytorch_lightning.loggers import TensorBoardLogger\n", "\n", "logger = TensorBoardLogger(save_dir=\"logs\", name=\"tsl_intro\", version=0)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "f4IU2vFKYVlD" }, "outputs": [], "source": [ "%load_ext tensorboard\n", "%tensorboard --logdir logs" ] },
{ "cell_type": "markdown", "metadata": { "id": "d1b57xoTYVlD" }, "source": [ "We let [`pytorch_lightning.Trainer`](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) handle the boilerplate for us. We can pass the datamodule directly to the trainer for fitting.\n", "\n", "In this case, the trainer calls the datamodule's `setup` method and then loads the training and validation sets." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "WSJLVCaFYVlD" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "from pytorch_lightning.callbacks import ModelCheckpoint\n", "\n", "checkpoint_callback = ModelCheckpoint(\n", " dirpath='logs',\n", " save_top_k=1,\n", " monitor='val_mae',\n", " mode='min',\n", ")\n", "\n", "trainer = pl.Trainer(max_epochs=100,\n", " logger=logger,\n", " accelerator='gpu' if torch.cuda.is_available() else 'cpu',\n", " devices=1,\n", " limit_train_batches=100, # end an epoch after 100 updates\n", " callbacks=[checkpoint_callback])\n", "\n", "trainer.fit(predictor, datamodule=dm)" ] },
{ "cell_type": "markdown", "metadata": { "id": "1bAHFLAEYVlE" }, "source": [ "### Testing\n", "\n", "Now let's see how the trained model behaves on unseen data." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "aV-v6EfCYVlE" }, "outputs": [], "source": [ "predictor.load_model(checkpoint_callback.best_model_path)\n", "predictor.freeze()\n", "\n", "trainer.test(predictor, datamodule=dm);" ] },
{ "cell_type": "markdown", "metadata": { "id": "QRd_FMpyYVlE" }, "source": [ "Cool! We succeeded in creating our first simple – yet effective – Spatiotemporal GNN!\n", "\n", "🥷 We are now **tsl ninjas**. We learned how to:\n", "\n", "* Load benchmark datasets\n", "* Organize data for processing\n", "* Preprocess the data\n", "* Build a Spatiotemporal GNN\n", "* Train and evaluate models\n", "\n", "We hope you enjoyed this introduction to tsl. Now go out there and build the game-changing STGNN 🌎\n", "\n", "🧡 The tsl team" ] } ], "metadata": { "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 1 }