{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Wm0emdxJJADg" }, "source": [ "# A Gentle Introduction to tsl\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "vjHwpRMYBCbt" }, "source": [ "### Spatiotemporal graph signals in tsl\n", "\n", "We now have a PyTorch-based dataset containing a collection of spatiotemporal graph signals. We can fetch samples in the same way we fetch elements of a Python list. Let's take a detailed look at the layout of a sample:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_G7-8HMcYVk_" }, "outputs": [], "source": [ "sample = torch_dataset[0]\n", "print(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "RXSGWvK5YVk_" }, "source": [ "A sample is of type [`tsl.data.Data`](https://torch-spatiotemporal.readthedocs.io/en/latest/modules/data_objects.html#tsl.data.Data), the base class for representing spatiotemporal graph signals in tsl. This class extends [`torch_geometric.data.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data), preserving all its functionalities and\n", "adding utilities for spatiotemporal data processing. 
The main APIs of `Data` include:\n", "\n", "* **`Data.input`**: view on the tensors stored in `Data` that are meant to serve as input to the model.\n", " In the simplest case of a single node-attribute matrix, we could just have `Data.input.x`.\n", "* **`Data.target`**: view on the tensors stored in `Data` used as labels to train the model.\n", " In the common case of a single label, we could just have `Data.target.y`.\n", "* **`Data.edge_index`**: graph connectivity in COO format (i.e., as node pairs).\n", "* **`Data.edge_weight`**: weights of the graph connectivity, if any.\n", "* **`Data.mask`**: binary mask indicating the data in `Data.target.y` to be used as ground truth for the loss (default is `None`).\n", "* **`Data.transform`**: mapping of `ScalerModule` objects, whose keys must be\n", " transformable (or transformed) tensors in `Data`.\n", "* **`Data.pattern`**: mapping containing the *pattern* for each tensor in `Data`. Patterns add information about the dimensions of tensors (e.g., specifying which are the time step and node dimensions).\n", "\n", "None of these attributes is required, and custom attributes can be added seamlessly.\n", "\n", "Let's look in more detail at how each of these attributes is composed." ] }, { "cell_type": "markdown", "metadata": { "id": "7tODoWS_YVk_" }, "source": [ "#### Input and Target\n", "\n", "`Data.input` and `Data.target` provide a **view** on the unique (shared) storage in `Data`, such that the same key in `Data.input` and `Data.target` cannot reference different objects."
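, "\n", "As a rough analogy in plain Python (not actual tsl code; the toy storage below is made up for illustration), the view semantics can be sketched as:\n", "\n", "```python\n", "# Toy analogy: 'input' and 'target' are views over one shared storage,\n", "# so the same key cannot point to two different objects.\n", "storage = {'x': [[1.0, 2.0]], 'y': [[3.0]]}\n", "\n", "input_view = {k: storage[k] for k in ('x',)}   # plays the role of Data.input\n", "target_view = {k: storage[k] for k in ('y',)}  # plays the role of Data.target\n", "\n", "# Both views reference the very same underlying objects:\n", "assert input_view['x'] is storage['x']\n", "assert target_view['y'] is storage['y']\n", "```"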
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TbB9HvSoYVk_" }, "outputs": [], "source": [ "sample.input.to_dict()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HdorD4uLYVk_" }, "outputs": [], "source": [ "sample.target.to_dict()" ] }, { "cell_type": "markdown", "metadata": { "id": "SNq8M-yaYVk_" }, "source": [ "#### Mask and Transform\n", "\n", "`mask` and `transform` are just symbolic links to the corresponding objects inside the storage. They also expose the properties `has_mask` and `has_transform`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rAlbTjo7YVlA" }, "outputs": [], "source": [ "if sample.has_mask:\n", " print(sample.mask)\n", "else:\n", " print(\"Sample has no mask.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zW6tx42LYVlA" }, "outputs": [], "source": [ "if sample.has_transform:\n", " print(sample.transform)\n", "else:\n", " print(\"Sample has no transformation functions.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "LR-LrjOwYVlA" }, "source": [ "#### Pattern\n", "\n", "The `pattern` mapping offers a quick glimpse of how data are arranged.\n", "The convention we use is the following:\n", "\n", "* `'t'` stands for the **time steps** dimension\n", "* `'n'` stands for a **node** dimension\n", "* `'e'` stands for the **edge** dimension\n", "* `'f'` stands for a **feature** dimension\n", "* `'b'` stands for the **batch** dimension\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4KeKgewEYVlA" }, "outputs": [], "source": [ "print(sample.pattern)\n", "print(\"================== Or we can print patterns and shapes together ==================\")\n", "print(sample)" ] }, { "cell_type": "markdown", "metadata": { "id": "6X47QXG6Vsgx" }, "source": [ "### Batching spatiotemporal graph signals\n", "\n", "Getting a batch of spatiotemporal graph signals from a single dataset is as simple as accessing multiple 
elements from a list:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WMYwon6dV9xt" }, "outputs": [], "source": [ "batch = torch_dataset[:5]\n", "print(batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "f4ba4122WKi1" }, "source": [ "As you can see, the time-varying elements (i.e., `x` and `y`) now have an additional dimension, denoted by the pattern `b`: the batch dimension. Along this new, first dimension, we stacked the features of the first 5 spatiotemporal graphs in the dataset.\n", "\n", "Note that this is possible only because we are assuming a fixed underlying topology, as confirmed by the shared `edge_index` and `edge_weight` attributes. How `Data` objects with different graphs are batched together is outside the scope of this notebook." ] }, { "cell_type": "markdown", "metadata": { "id": "9j-pA8lCYVlA" }, "source": [ "### Preparing the dataset for training\n",