{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PatchTSMixer in Hugging Face - Getting Started\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "`PatchTSMixer` is a lightweight time-series modeling approach based on the MLP-Mixer architecture. It was proposed in [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://huggingface.co/papers/2306.09364) by IBM Research authors Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong and Jayant Kalagnanam.\n", "\n", "To promote open-source adoption, IBM Research has partnered with the Hugging Face team to release this model in the Transformers library.\n", "\n", "The [Hugging Face implementation](https://huggingface.co/docs/transformers/main/en/model_doc/patchtsmixer) supports lightweight mixing across patches, channels, and hidden features for multivariate time-series modeling. It also supports attention mechanisms ranging from simple gated attention to more complex self-attention blocks, which can be customized as needed. The model can be pretrained and subsequently used for downstream tasks such as forecasting, classification, and regression.\n", "\n", "According to the paper, `PatchTSMixer` outperforms state-of-the-art MLP and Transformer models in forecasting by a margin of 8-60%. It also outperforms the latest strong Patch-Transformer baselines by 1-2%, with a 2-3X reduction in memory and runtime. For more details, refer to the [paper](https://arxiv.org/pdf/2306.09364.pdf).\n", "\n", "In this blog, we show how to get started with PatchTSMixer. We will first demonstrate the forecasting capability of `PatchTSMixer` on the Electricity dataset. 
We will then demonstrate the transfer learning capability of PatchTSMixer by using the model trained on Electricity to do zero-shot forecasting on the `ETTh2` dataset.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation\n", "This demo requires Hugging Face [`Transformers`](https://github.com/huggingface/transformers) for the model and the IBM `tsfm` package for auxiliary data pre-processing.\n", "Both can be installed by following the steps below.\n", "\n", "1. Install the IBM Time Series Foundation Model Repository [`tsfm`](https://github.com/ibm/tsfm).\n", "```\n", "pip install git+https://github.com/IBM/tsfm.git\n", "```\n", "2. Install Hugging Face [`Transformers`](https://github.com/huggingface/transformers#installation)\n", "```\n", "pip install transformers\n", "```\n", "3. Verify the installation by running the following commands in a `python` session.\n", "```\n", "from transformers import PatchTSMixerConfig\n", "from tsfm_public.toolkit.dataset import ForecastDFDataset\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install git+https://github.com/IBM/tsfm.git transformers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import PatchTSMixerConfig\n", "from tsfm_public.toolkit.dataset import ForecastDFDataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1: Forecasting on Electricity dataset" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "2023-12-11 01:25:50.313015: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-12-11 01:25:50.313102: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-12-11 01:25:50.313132: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2023-12-11 01:25:51.234452: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import os\n", "import random\n", "\n", "from transformers import (\n", " EarlyStoppingCallback,\n", " PatchTSMixerConfig,\n", " PatchTSMixerForPrediction,\n", " Trainer,\n", " TrainingArguments,\n", ")\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "from tsfm_public.toolkit.dataset import ForecastDFDataset\n", "from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n", "from tsfm_public.toolkit.util import select_by_index" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " ### Set seed" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from transformers import set_seed\n", "\n", "set_seed(42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load and prepare datasets\n", "\n", "In the next cell, please adjust the following parameters to suit your application:\n", "- `dataset_path`: Path to a local .csv file, or the web address of a .csv file, containing the data of interest. 
Data is loaded with pandas, so anything supported by\n", "[`pd.read_csv`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html) is supported.\n", "- `timestamp_column`: Column name containing timestamp information. Use `None` if there is no such column.\n", "- `id_columns`: List of column names specifying the IDs of different time series. If no ID column exists, use `[]`.\n", "- `forecast_columns`: List of columns to be modeled.\n", "- `context_length`: The amount of historical data used as input to the model. Windows of the input time series data with length equal to\n", "`context_length` will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created\n", "so that they are contained within a single time series (i.e., a single ID).\n", "- `forecast_horizon`: Number of timestamps to forecast in the future.\n", "- `train_start_index`, `train_end_index`: The start and end indices in the loaded data that delineate the training data.\n", "- `valid_start_index`, `valid_end_index`: The start and end indices in the loaded data that delineate the validation data.\n", "- `test_start_index`, `test_end_index`: The start and end indices in the loaded data that delineate the test data.\n", "- `num_workers`: Number of CPU workers in the PyTorch dataloader.\n", "- `batch_size`: Batch size.\n", "\n", "The data is first loaded into a pandas DataFrame and split into training, validation, and test parts. The DataFrames are then converted into the PyTorch datasets required for training."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "PRETRAIN_AGAIN = True\n", "# Download ECL data from https://github.com/zhouhaoyi/Informer2020\n", "dataset_path = \"~/Downloads/ECL.csv\"\n", "timestamp_column = \"date\"\n", "id_columns = []\n", "\n", "context_length = 512\n", "forecast_horizon = 96\n", "patch_length = 8\n", "num_workers = 16 # Reduce this if you have a low number of CPU cores\n", "batch_size = 64 # Adjust according to GPU memory" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "if PRETRAIN_AGAIN:\n", " data = pd.read_csv(\n", " dataset_path,\n", " parse_dates=[timestamp_column],\n", " )\n", " forecast_columns = list(data.columns[1:])\n", "\n", " # split the data chronologically: ~70% train, ~10% validation, ~20% test\n", " num_train = int(len(data) * 0.7)\n", " num_test = int(len(data) * 0.2)\n", " num_valid = len(data) - num_train - num_test\n", " border1s = [\n", " 0,\n", " num_train - context_length,\n", " len(data) - num_test - context_length,\n", " ]\n", " border2s = [num_train, num_train + num_valid, len(data)]\n", "\n", " train_start_index = border1s[0] # 0 is the beginning of the dataset\n", " train_end_index = border2s[0]\n", "\n", " # we shift the start of the evaluation period back by context length so that\n", " # the first evaluation timestamp is immediately following the training data\n", " valid_start_index = border1s[1]\n", " valid_end_index = border2s[1]\n", "\n", " test_start_index = border1s[2]\n", " test_end_index = border2s[2]\n", "\n", " train_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=train_start_index,\n", " end_index=train_end_index,\n", " )\n", " valid_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=valid_start_index,\n", " end_index=valid_end_index,\n", " )\n", " test_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=test_start_index,\n", " 
end_index=test_end_index,\n", " )\n", "\n", " 
tsp = TimeSeriesPreprocessor(\n", " context_length=context_length,\n", " timestamp_column=timestamp_column,\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " scaling=True,\n", " )\n", " tsp.train(train_data)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "if PRETRAIN_AGAIN:\n", " train_dataset = ForecastDFDataset(\n", " tsp.preprocess(train_data),\n", " id_columns=id_columns,\n", " timestamp_column=timestamp_column,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", " )\n", " valid_dataset = ForecastDFDataset(\n", " tsp.preprocess(valid_data),\n", " id_columns=id_columns,\n", " timestamp_column=timestamp_column,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", " )\n", " test_dataset = ForecastDFDataset(\n", " tsp.preprocess(test_data),\n", " id_columns=id_columns,\n", " timestamp_column=timestamp_column,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " ## Configure the PatchTSMixer model\n", "\n", "Next, we instantiate a randomly initialized PatchTSMixer model with a configuration. The settings below control the different hyperparameters related to the architecture.\n", " - `num_input_channels`: The number of input channels (or dimensions) in the time series data. This is\n", " automatically set to the number of forecast columns.\n", " - `context_length`: As described above, the amount of historical data used as input to the model.\n", " - `prediction_length`: The same as the forecast horizon described above.\n", " - `patch_length`: The patch length for the `PatchTSMixer` model. 
It is recommended to choose a value that evenly divides `context_length`.\n", " - `patch_stride`: The stride used when extracting patches from the context window.\n", " - `d_model`: Hidden feature dimension of the model.\n", " - `num_layers`: The number of model layers.\n", " - `expansion_factor`: Expansion factor used inside the MLP blocks (recommended range: 2-5).\n", " - `dropout`: Dropout probability for all fully connected layers in the encoder.\n", " - `head_dropout`: Dropout probability used in the head of the model.\n", " - `mode`: PatchTSMixer operating mode, either \"common_channel\" or \"mix_channel\". \"common_channel\" works in channel-independent mode, while \"mix_channel\" adds explicit channel mixing. For pretraining, use \"common_channel\".\n", " - `scaling`: Per-window standard scaling. Recommended value: \"std\".\n", "\n", "For full details on the parameters, refer to the [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/patchtsmixer#transformers.PatchTSMixerConfig).\n", "\n", "We recommend that you only adjust the values in the next cell." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "if PRETRAIN_AGAIN:\n", " config = PatchTSMixerConfig(\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", " patch_length=patch_length,\n", " num_input_channels=len(forecast_columns),\n", " patch_stride=patch_length,\n", " d_model=16,\n", " num_layers=8,\n", " expansion_factor=2,\n", " dropout=0.2,\n", " head_dropout=0.2,\n", " mode=\"common_channel\",\n", " scaling=\"std\",\n", " )\n", " model = PatchTSMixerForPrediction(config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " ## Train model\n", "\n", " Next, we use the Hugging Face [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) class to train the model with the direct forecasting strategy. 
We first define the [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments), which specify training hyperparameters such as the number of epochs and the learning rate." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='2450' max='7000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [2450/7000 21:35 < 40:08, 1.89 it/s, Epoch 35/100]\n", " </div>\n", " <table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>Epoch</th>\n", " <th>Training Loss</th>\n", " <th>Validation Loss</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>1</td>\n", " <td>0.247100</td>\n", " <td>0.141067</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.168600</td>\n", " <td>0.127757</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.156500</td>\n", " <td>0.122327</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.150300</td>\n", " <td>0.118918</td>\n", " </tr>\n", " <tr>\n", " <td>5</td>\n", " <td>0.146000</td>\n", " <td>0.116496</td>\n", " </tr>\n", " <tr>\n", " <td>6</td>\n", " <td>0.143100</td>\n", " <td>0.114968</td>\n", " </tr>\n", " <tr>\n", " <td>7</td>\n", " <td>0.140800</td>\n", " <td>0.113678</td>\n", " </tr>\n", " <tr>\n", " <td>8</td>\n", " <td>0.139200</td>\n", " <td>0.113057</td>\n", " </tr>\n", " <tr>\n", " <td>9</td>\n", " <td>0.137900</td>\n", " <td>0.112405</td>\n", " </tr>\n", " <tr>\n", " <td>10</td>\n", " <td>0.136900</td>\n", 
" <td>0.112225</td>\n", " </tr>\n", " <tr>\n", " <td>11</td>\n", " <td>0.136100</td>\n", " <td>0.112087</td>\n", " </tr>\n", " <tr>\n", " <td>12</td>\n", " <td>0.135400</td>\n", " <td>0.112330</td>\n", " </tr>\n", " <tr>\n", " <td>13</td>\n", " <td>0.134700</td>\n", " <td>0.111778</td>\n", " </tr>\n", " <tr>\n", " <td>14</td>\n", " <td>0.134100</td>\n", " <td>0.111702</td>\n", " </tr>\n", " <tr>\n", " <td>15</td>\n", " <td>0.133700</td>\n", " <td>0.110964</td>\n", " </tr>\n", " <tr>\n", " <td>16</td>\n", " <td>0.133100</td>\n", " <td>0.111164</td>\n", " </tr>\n", " <tr>\n", " <td>17</td>\n", " <td>0.132800</td>\n", " <td>0.111063</td>\n", " </tr>\n", " <tr>\n", " <td>18</td>\n", " <td>0.132400</td>\n", " <td>0.111088</td>\n", " </tr>\n", " <tr>\n", " <td>19</td>\n", " <td>0.132100</td>\n", " <td>0.110905</td>\n", " </tr>\n", " <tr>\n", " <td>20</td>\n", " <td>0.131800</td>\n", " <td>0.110844</td>\n", " </tr>\n", " <tr>\n", " <td>21</td>\n", " <td>0.131300</td>\n", " <td>0.110831</td>\n", " </tr>\n", " <tr>\n", " <td>22</td>\n", " <td>0.131100</td>\n", " <td>0.110278</td>\n", " </tr>\n", " <tr>\n", " <td>23</td>\n", " <td>0.130700</td>\n", " <td>0.110591</td>\n", " </tr>\n", " <tr>\n", " <td>24</td>\n", " <td>0.130600</td>\n", " <td>0.110319</td>\n", " </tr>\n", " <tr>\n", " <td>25</td>\n", " <td>0.130300</td>\n", " <td>0.109900</td>\n", " </tr>\n", " <tr>\n", " <td>26</td>\n", " <td>0.130000</td>\n", " <td>0.109982</td>\n", " </tr>\n", " <tr>\n", " <td>27</td>\n", " <td>0.129900</td>\n", " <td>0.109975</td>\n", " </tr>\n", " <tr>\n", " <td>28</td>\n", " <td>0.129600</td>\n", " <td>0.110128</td>\n", " </tr>\n", " <tr>\n", " <td>29</td>\n", " <td>0.129300</td>\n", " <td>0.109995</td>\n", " </tr>\n", " <tr>\n", " <td>30</td>\n", " <td>0.129100</td>\n", " <td>0.109868</td>\n", " </tr>\n", " <tr>\n", " <td>31</td>\n", " <td>0.129000</td>\n", " <td>0.109928</td>\n", " </tr>\n", " <tr>\n", " <td>32</td>\n", " <td>0.128700</td>\n", " <td>0.109823</td>\n", " </tr>\n", " 
<tr>\n", " <td>33</td>\n", " <td>0.128500</td>\n", " <td>0.109863</td>\n", " </tr>\n", " <tr>\n", " <td>34</td>\n", " <td>0.128400</td>\n", " <td>0.109794</td>\n", " </tr>\n", " <tr>\n", " <td>35</td>\n", " <td>0.128100</td>\n", " <td>0.109945</td>\n", " </tr>\n", " </tbody>\n", "</table><p>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] } ], "source": [ "if PRETRAIN_AGAIN:\n", " training_args = TrainingArguments(\n", " output_dir=\"./checkpoint/patchtsmixer/electricity/pretrain/output/\",\n", " overwrite_output_dir=True,\n", " learning_rate=0.001,\n", " num_train_epochs=100, # For a quick test of this notebook, set it to 1\n", " do_eval=True,\n", " evaluation_strategy=\"epoch\",\n", " per_device_train_batch_size=batch_size,\n", " 
per_device_eval_batch_size=batch_size,\n", " dataloader_num_workers=num_workers,\n", " report_to=\"tensorboard\",\n", " save_strategy=\"epoch\",\n", " logging_strategy=\"epoch\",\n", " save_total_limit=3,\n", " logging_dir=\"./checkpoint/patchtsmixer/electricity/pretrain/logs/\", # Make sure to specify a logging directory\n", " load_best_model_at_end=True, # Load the best model when training ends\n", " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n", " greater_is_better=False, # For loss\n", " label_names=[\"future_values\"],\n", " # max_steps=20,\n", " )\n", "\n", " # Create the early stopping callback\n", " early_stopping_callback = EarlyStoppingCallback(\n", " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n", " early_stopping_threshold=0.0001, # Minimum improvement required to consider as improvement\n", " )\n", "\n", " # define trainer\n", " trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=valid_dataset,\n", " callbacks=[early_stopping_callback],\n", " )\n", "\n", " # pretrain\n", " trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " ## Evaluate the model on the test set.\n", "\n", "**Note that the training and evaluation loss for PatchTSMixer is the Mean Squared Error (MSE) loss. 
Hence, we do not separately compute the MSE metric in any of the following evaluation experiments.**\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='21' max='21' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [21/21 00:03]\n", " </div>\n", " " ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Test result:\n", "{'eval_loss': 0.12884521484375, 'eval_runtime': 5.7532, 'eval_samples_per_second': 897.763, 'eval_steps_per_second': 3.65, 'epoch': 35.0}\n" ] } ], "source": [ "if PRETRAIN_AGAIN:\n", " results = trainer.evaluate(test_dataset)\n", " print(\"Test result:\")\n", " print(results)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We get a test MSE of 0.1288, which is a state-of-the-art (SOTA) result on the Electricity data."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ " ## Save model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "if PRETRAIN_AGAIN:\n", " save_dir = \"patchtsmixer/electricity/model/pretrain/\"\n", " os.makedirs(save_dir, exist_ok=True)\n", " trainer.save_model(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2: Transfer Learning from Electricity to `ETTh2`\n", "\n", "In this section, we demonstrate the transfer learning capability of the `PatchTSMixer` model.\n", "We use the model pretrained on the Electricity dataset to do zero-shot forecasting on the `ETTh2` dataset.\n", "\n", "By transfer learning, we mean that we first pretrain the model for a forecasting task on a `source` dataset (which we did above on the `Electricity` dataset). Then, we use the\n", " pretrained model for zero-shot forecasting on a `target` dataset. By zero-shot, we mean that we test the performance in the `target` domain without any additional training. The hope is that the knowledge the model gains during pretraining transfers to a different dataset.\n", "\n", "Subsequently, we do linear probing and then full finetuning of the pretrained model on the `train` split of the target data, and evaluate the forecasting performance on the `test` split of the target data. In this example, the source dataset is the Electricity dataset and the target dataset is `ETTh2`.\n", "\n", "## Transfer Learning on `ETTh2` data\n", "\n", "All evaluations are on the `test` part of the `ETTh2` data:\n", "\n", "1. Directly evaluate the Electricity-pretrained model. This is the zero-shot performance.\n", "2. Evaluate after linear probing.\n", "3. Evaluate after full finetuning.\n", "\n", "### Load `ETTh2` dataset\n", "\n", "Below, we load the `ETTh2` dataset as a pandas DataFrame and create three splits for training, validation, and testing. 
We then leverage the `TimeSeriesPreprocessor` class to prepare each split for the model." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "dataset = \"ETTh2\"" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading target dataset: ETTh2\n" ] } ], "source": [ "print(f\"Loading target dataset: {dataset}\")\n", "dataset_path = f\"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv\"\n", "timestamp_column = \"date\"\n", "id_columns = []\n", "forecast_columns = [\"HUFL\", \"HULL\", \"MUFL\", \"MULL\", \"LUFL\", \"LULL\", \"OT\"]\n", "train_start_index = None # None indicates beginning of dataset\n", "train_end_index = 12 * 30 * 24\n", "\n", "# we shift the start of the evaluation period back by context length so that\n", "# the first evaluation timestamp is immediately following the training data\n", "valid_start_index = 12 * 30 * 24 - context_length\n", "valid_end_index = 12 * 30 * 24 + 4 * 30 * 24\n", "\n", "test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length\n", "test_end_index = 12 * 30 * 24 + 8 * 30 * 24" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TimeSeriesPreprocessor {\n", " \"context_length\": 64,\n", " \"feature_extractor_type\": \"TimeSeriesPreprocessor\",\n", " \"id_columns\": [],\n", " \"input_columns\": [\n", " \"HUFL\",\n", " \"HULL\",\n", " \"MUFL\",\n", " \"MULL\",\n", " \"LUFL\",\n", " \"LULL\",\n", " \"OT\"\n", " ],\n", " \"output_columns\": [\n", " \"HUFL\",\n", " \"HULL\",\n", " \"MUFL\",\n", " \"MULL\",\n", " \"LUFL\",\n", " \"LULL\",\n", " \"OT\"\n", " ],\n", " \"prediction_length\": null,\n", " \"processor_class\": \"TimeSeriesPreprocessor\",\n", " \"scaler_dict\": {\n", " \"0\": {\n", " \"copy\": true,\n", " \"feature_names_in_\": [\n", " \"HUFL\",\n", " \"HULL\",\n", " \"MUFL\",\n", " \"MULL\",\n", " 
\"LUFL\",\n", " \"LULL\",\n", " \"OT\"\n", " ],\n", " \"mean_\": [\n", " 41.53683496078959,\n", " 12.273452896210882,\n", " 46.60977329964991,\n", " 10.526153112865156,\n", " 1.1869920139097505,\n", " -2.373217913729173,\n", " 26.872023494265697\n", " ],\n", " \"n_features_in_\": 7,\n", " \"n_samples_seen_\": 8640,\n", " \"scale_\": [\n", " 10.448841072588488,\n", " 4.587112566531959,\n", " 16.858190332598408,\n", " 3.018605566682919,\n", " 4.641011217319063,\n", " 8.460910779279644,\n", " 11.584718923414682\n", " ],\n", " \"var_\": [\n", " 109.17827976021215,\n", " 21.04160169803542,\n", " 284.19858129011436,\n", " 9.111979567209104,\n", " 21.538985119281367,\n", " 71.58701121493046,\n", " 134.20571253452223\n", " ],\n", " \"with_mean\": true,\n", " \"with_std\": true\n", " }\n", " },\n", " \"scaling\": true,\n", " \"time_series_task\": \"forecasting\",\n", " \"timestamp_column\": \"date\"\n", "}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = pd.read_csv(\n", " dataset_path,\n", " parse_dates=[timestamp_column],\n", ")\n", "\n", "train_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=train_start_index,\n", " end_index=train_end_index,\n", ")\n", "valid_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=valid_start_index,\n", " end_index=valid_end_index,\n", ")\n", "test_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=test_start_index,\n", " end_index=test_end_index,\n", ")\n", "\n", "tsp = TimeSeriesPreprocessor(\n", " timestamp_column=timestamp_column,\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " scaling=True,\n", ")\n", "tsp.train(train_data)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "train_dataset = ForecastDFDataset(\n", " tsp.preprocess(train_data),\n", " id_columns=id_columns,\n", " 
input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")\n", "valid_dataset = ForecastDFDataset(\n", " tsp.preprocess(valid_data),\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")\n", "test_dataset = ForecastDFDataset(\n", " tsp.preprocess(test_data),\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Zero-shot forecasting on `ETTh2`\n", "\n", "Since we are testing forecasting performance out of the box, we load the model that we pretrained on the Electricity data above." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pretrained model\n", "Done\n" ] } ], "source": [ "print(\"Loading pretrained model\")\n", "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n", " \"patchtsmixer/electricity/model/pretrain/\"\n", ")\n", "print(\"Done\")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Doing zero-shot forecasting on target data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='22' max='11' style='width:300px; height:20px; vertical-align: 
middle;'></progress>\n", " [11/11 02:52]\n", " </div>\n", " " ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Target data zero-shot forecasting result:\n", "{'eval_loss': 0.3038313388824463, 'eval_runtime': 1.8364, 'eval_samples_per_second': 1516.562, 'eval_steps_per_second': 5.99}\n" ] } ], "source": [ "finetune_forecast_args = TrainingArguments(\n", " output_dir=\"./checkpoint/patchtsmixer/transfer/finetune/output/\",\n", " overwrite_output_dir=True,\n", " learning_rate=0.0001,\n", " num_train_epochs=100,\n", " do_eval=True,\n", " evaluation_strategy=\"epoch\",\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " dataloader_num_workers=num_workers,\n", " report_to=\"tensorboard\",\n", " save_strategy=\"epoch\",\n", " logging_strategy=\"epoch\",\n", " save_total_limit=3,\n", " logging_dir=\"./checkpoint/patchtsmixer/transfer/finetune/logs/\", # Make sure to specify a logging directory\n", " load_best_model_at_end=True, # Load the best model when training ends\n", " metric_for_best_model=\"eval_loss\", # Metric used to pick the best model and for early stopping\n", " greater_is_better=False, # Lower eval_loss is better\n", ")\n", "\n", "# Stop training once the validation loss stops improving\n", "early_stopping_callback = EarlyStoppingCallback(\n", " early_stopping_patience=5, # Number of epochs with no improvement after which to stop\n", " early_stopping_threshold=0.001, # Minimum decrease in eval_loss that counts as an improvement\n", ")\n", "\n", "finetune_forecast_trainer = Trainer(\n", " model=finetune_forecast_model,\n", " args=finetune_forecast_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=valid_dataset,\n", " callbacks=[early_stopping_callback],\n", ")\n", "\n", "print(\"\\n\\nDoing zero-shot forecasting on target data\")\n", "result = finetune_forecast_trainer.evaluate(test_dataset)\n", "print(\"Target data 
zero-shot forecasting result:\")\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As can be seen, zero-shot forecasting gives a mean squared error (MSE) of 0.304 on the `test` split, which is close to the state-of-the-art result.\n", "\n", "Next, let's see how linear probing performs. Linear probing trains only the final linear prediction head on top of the frozen pretrained backbone, and is often used to assess how useful the pretrained features are." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear probing on `ETTh2`\n", "We can quickly run linear probing on the `train` split of the target data and check whether the `test` performance improves." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Linear probing on the target data\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='416' max='3200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [ 416/3200 01:01 < 06:53, 6.73 it/s, Epoch 13/100]\n", " </div>\n", " <table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>Epoch</th>\n", " <th>Training Loss</th>\n", " <th>Validation Loss</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>1</td>\n", " <td>0.447000</td>\n", " <td>0.216436</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.438600</td>\n", " <td>0.215667</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.429400</td>\n", " <td>0.215104</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.422500</td>\n", " <td>0.213820</td>\n", " </tr>\n", " <tr>\n", " <td>5</td>\n", " <td>0.418500</td>\n", " <td>0.213585</td>\n", " </tr>\n", " <tr>\n", " <td>6</td>\n", " <td>0.415000</td>\n", " <td>0.213016</td>\n", " </tr>\n", " <tr>\n", " <td>7</td>\n", " <td>0.412000</td>\n", " <td>0.213067</td>\n", " </tr>\n", " <tr>\n", " <td>8</td>\n", "
<td>0.412400</td>\n", " <td>0.211993</td>\n", " </tr>\n", " <tr>\n", " <td>9</td>\n", " <td>0.405900</td>\n", " <td>0.212460</td>\n", " </tr>\n", " <tr>\n", " <td>10</td>\n", " <td>0.405300</td>\n", " <td>0.211772</td>\n", " </tr>\n", " <tr>\n", " <td>11</td>\n", " <td>0.406200</td>\n", " <td>0.212154</td>\n", " </tr>\n", " <tr>\n", " <td>12</td>\n", " <td>0.400600</td>\n", " <td>0.212082</td>\n", " </tr>\n", " <tr>\n", " <td>13</td>\n", " <td>0.405300</td>\n", " <td>0.211458</td>\n", " </tr>\n", " </tbody>\n", "</table><p>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", 
"/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluating\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='11' max='11' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [11/11 00:00]\n", " </div>\n", " " ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Target data head/linear probing result:\n", "{'eval_loss': 0.27119266986846924, 'eval_runtime': 1.7621, 'eval_samples_per_second': 1580.478, 'eval_steps_per_second': 6.242, 'epoch': 13.0}\n" ] } ], "source": [ "# Freeze the backbone of the model\n", "for param in finetune_forecast_trainer.model.model.parameters():\n", " param.requires_grad = False\n", "\n", "print(\"\\n\\nLinear probing on the target data\")\n", "finetune_forecast_trainer.train()\n", "print(\"Evaluating\")\n", "result = finetune_forecast_trainer.evaluate(test_dataset)\n", 
"print(\"Target data head/linear probing result:\")\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As can be seen, by training only the linear head on top of the frozen backbone, the `test` MSE decreased from 0.304 to 0.271, a state-of-the-art result." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['patchtsmixer/electricity/model/transfer/ETTh2/preprocessor/preprocessor_config.json']" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/model/linear_probe/\"\n", "os.makedirs(save_dir, exist_ok=True)\n", "finetune_forecast_trainer.save_model(save_dir)\n", "\n", "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/preprocessor/\"\n", "os.makedirs(save_dir, exist_ok=True)\n", "tsp.save_pretrained(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's see whether a full finetune of the model on the target dataset improves the results further." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Full finetuning on `ETTh2`\n", "\n", "We can also finetune the full model (instead of training only the last linear layer as above) on the `train` split of the target data and check for a further `test` improvement. The code is similar to the linear probing code above, except that we reload the pretrained model and do not freeze any parameters."
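, "\n", "\n", "In this notebook, the two setups differ mainly in which parameters receive gradients. The cell below is a minimal PyTorch sketch on a hypothetical toy model (not `PatchTSMixerForPrediction` itself): its `model` attribute stands in for the backbone that the linear probing cell froze through `finetune_forecast_trainer.model.model`, and its `head` attribute stands in for the prediction head. Freezing the backbone leaves only the head trainable, while a freshly loaded model, as used for full finetuning, trains every parameter." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "\n", "# Hypothetical toy forecaster: a small backbone plus a linear head.\n", "class ToyForecaster(nn.Module):\n", "    def __init__(self, in_dim=64, hidden=32, horizon=96):\n", "        super().__init__()\n", "        self.model = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())\n", "        self.head = nn.Linear(hidden, horizon)\n", "\n", "    def forward(self, x):\n", "        return self.head(self.model(x))\n", "\n", "\n", "def count_trainable(module):\n", "    # Number of scalar parameters that would receive gradient updates.\n", "    return sum(p.numel() for p in module.parameters() if p.requires_grad)\n", "\n", "\n", "# Linear probing: freeze the backbone, so only the head is trained.\n", "probe_net = ToyForecaster()\n", "for param in probe_net.model.parameters():\n", "    param.requires_grad = False\n", "probe_params = count_trainable(probe_net)\n", "\n", "# Full finetuning: a freshly built model with nothing frozen.\n", "full_net = ToyForecaster()\n", "full_params = count_trainable(full_net)\n", "\n", "print(f'linear probing trains {probe_params} parameters')  # 32*96 + 96 = 3168\n", "print(f'full finetuning trains {full_params} parameters')  # 3168 + 64*32 + 32 = 5248"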
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Finetuning on the target data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='288' max='3200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [ 288/3200 00:44 < 07:34, 6.40 it/s, Epoch 9/100]\n", " </div>\n", " <table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>Epoch</th>\n", " <th>Training Loss</th>\n", " <th>Validation Loss</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>1</td>\n", " <td>0.432900</td>\n", " <td>0.215200</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.416700</td>\n", " <td>0.210919</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.401400</td>\n", " <td>0.209932</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.392900</td>\n", " <td>0.208808</td>\n", " </tr>\n", " <tr>\n", " <td>5</td>\n", " <td>0.388100</td>\n", " <td>0.209692</td>\n", " </tr>\n", " <tr>\n", " <td>6</td>\n", " <td>0.375900</td>\n", " <td>0.209546</td>\n", " </tr>\n", " <tr>\n", " <td>7</td>\n", " <td>0.370000</td>\n", " <td>0.210207</td>\n", " </tr>\n", " <tr>\n", " <td>8</td>\n", " <td>0.367000</td>\n", " <td>0.211601</td>\n", " </tr>\n", " <tr>\n", " <td>9</td>\n", " <td>0.359400</td>\n", " <td>0.211405</td>\n", " </tr>\n", " </tbody>\n", "</table><p>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ 
"/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n", "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Evaluating\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='11' max='11' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [11/11 00:00]\n", " </div>\n", " " ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Target data full finetune result:\n", "{'eval_loss': 0.2734043300151825, 'eval_runtime': 1.5853, 'eval_samples_per_second': 1756.725, 'eval_steps_per_second': 6.939, 'epoch': 9.0}\n" ] } ], "source": [ "# Reload the model\n", "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n", " \"patchtsmixer/electricity/model/pretrain/\"\n", ")\n", "finetune_forecast_trainer = Trainer(\n", " model=finetune_forecast_model,\n", " args=finetune_forecast_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=valid_dataset,\n", " 
callbacks=[early_stopping_callback],\n", ")\n", "print(\"\\n\\nFinetuning on the target data\")\n", "finetune_forecast_trainer.train()\n", "print(\"Evaluating\")\n", "result = finetune_forecast_trainer.evaluate(test_dataset)\n", "print(\"Target data full finetune result:\")\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, full finetuning brings no further improvement: the `test` MSE is 0.273, compared with 0.271 after linear probing. Let's save the model anyway." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/model/fine_tuning/\"\n", "os.makedirs(save_dir, exist_ok=True)\n", "finetune_forecast_trainer.save_model(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "In this blog, we presented a step-by-step guide to using PatchTSMixer for forecasting and transfer learning. We intend to facilitate the seamless integration of the PatchTSMixer HF model into your forecasting use cases. We trust that this content serves as a useful resource to expedite your adoption of PatchTSMixer. Thank you for tuning in to our blog, and we hope you find this information beneficial for your projects." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 4 }