{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Patch Time Series Transformer in HuggingFace - Getting Started\n", "\n", "In this blog, we provide examples of how to get started with PatchTST. We first demonstrate the forecasting capability of `PatchTST` on the Electricity data. We will then demonstrate the transfer learning capability of `PatchTST` by using the previously trained model to do zero-shot forecasting on the electrical transformer (ETTh1) dataset. The zero-shot forecasting performance will denote the `test` performance of the model in the `target` domain, without any training on the target domain. Subsequently, we will do linear probing and (then) finetuning of the pretrained model on the `train` part of the target data and will validate the forecasting performance on the `test` part of the target data.\n", "\n", "The `PatchTST` model was proposed in A Time Series is Worth [64 Words: Long-term Forecasting with Transformers](https://huggingface.co/papers/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam and presented at ICLR 2023.\n", "\n", "\n", "## Quick overview of PatchTST\n", "\n", "At a high level, the model vectorizes individual time series in a batch into patches of a given size and encodes the resulting sequence of vectors via a Transformer that then outputs the prediction length forecast via an appropriate head.\n", "\n", "The model is based on two key components: \n", " 1. segmentation of time series into subseries-level patches which serve as input tokens to the Transformer; \n", " 2. channel-independence where each channel contains a single univariate time series that shares the same embedding and Transformer weights across all the series, i.e. a [global](https://doi.org/10.1016/j.ijforecast.2021.03.004) univariate model. \n", "\n", "The patching design naturally has three-fold benefit: \n", " - local semantic information is retained in the embedding; \n", " - computation and memory usage of the attention maps are quadratically reduced given the same look-back window via strides between patches; and \n", " - the model can attend longer history via a trade-off between the patch length (input vector size) and the context length (number of sequences).\n", " \n", "\n", "In addition, `PatchTST` has a modular design to seamlessly support masked time series pre-training as well as direct time series forecasting.\n", "\n", "| ![PatchTST model schematics](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/patchtst/patchtst-arch.png) |\n", "|:--:|\n", "|(a) `PatchTST` model overview where a batch of \\\\(M\\\\) time series each of length \\\\(L\\\\) are processed independently (by reshaping them into the batch dimension) via a Transformer backbone and then reshaping the resulting batch back into \\\\(M \\\\) series of prediction length \\\\(T\\\\). Each *univariate* series can be processed in a supervised fashion (b) where the patched set of vectors is used to output the full prediction length or in a self-supervised fashion (c) where masked patches are predicted. |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Installation\n", "\n", "This demo requires Hugging Face [`Transformers`](https://github.com/huggingface/transformers) for the model, and the IBM `tsfm` package for auxiliary data pre-processing.\n", "We can install both by cloning the `tsfm` repository and following the below steps.\n", "\n", "1. 
Install the public IBM Time Series Foundation Model library [`tsfm`](https://github.com/ibm/tsfm) directly from GitHub.\n", " ```bash\n", " pip install git+https://github.com/IBM/tsfm.git\n", " ```\n", "2. Install Hugging Face [`Transformers`](https://github.com/huggingface/transformers#installation).\n", " ```bash\n", " pip install transformers\n", " ```\n", "3. Test it with the following commands in a `python` terminal.\n", " ```python\n", " from transformers import PatchTSTConfig\n", " from tsfm_public.toolkit.dataset import ForecastDFDataset\n", " ```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1: Forecasting on the Electricity dataset\n", "Here we train a `PatchTST` model directly on the Electricity dataset (available from https://github.com/zhouhaoyi/Informer2020), and evaluate its performance." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Standard\n", "import os\n", "\n", "# Third Party\n", "from transformers import (\n", " EarlyStoppingCallback,\n", " PatchTSTConfig,\n", " PatchTSTForPrediction,\n", " set_seed,\n", " Trainer,\n", " TrainingArguments,\n", ")\n", "import numpy as np\n", "import pandas as pd\n", "\n", "# First Party\n", "from tsfm_public.toolkit.dataset import ForecastDFDataset\n", "from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n", "from tsfm_public.toolkit.util import select_by_index\n", "\n", "# suppress some warnings\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\", module=\"torch\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set seed" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "set_seed(2023)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load and prepare datasets\n", "\n", " In the next cell, please adjust the following parameters to suit your application:\n", " - `dataset_path`: path to a local .csv file, or web address of a csv file for the data of interest. Data is loaded with pandas, so anything supported by\n", " `pd.read_csv` is supported (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html).\n", " - `timestamp_column`: column name containing timestamp information, use `None` if there is no such column.\n", " - `id_columns`: List of column names specifying the IDs of different time series. If no ID column exists, use `[]`.\n", " - `forecast_columns`: List of columns to be modeled.\n", " - `context_length`: The amount of historical data used as input to the model. Windows of the input time series data with length equal to `context_length` will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created so that they are contained within a single time series (i.e., a single ID).\n", " - `forecast_horizon`: Number of timestamps to forecast in the future.\n", " - `train_start_index`, `train_end_index`: the start and end indices in the loaded data which delineate the training data.\n", " - `valid_start_index`, `valid_end_index`: the start and end indices in the loaded data which delineate the validation data.\n", " - `test_start_index`, `test_end_index`: the start and end indices in the loaded data which delineate the test data.\n", " - `patch_length`: The patch length for the `PatchTST` model. 
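With the non-overlapping patching used in this notebook (`patch_stride` equal to `patch_length`), a `context_length` of 512 and a `patch_length` of 16 give 512 / 16 = 32 patch tokens per input channel. 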
It is recommended to choose a value that evenly divides `context_length`.\n", " - `num_workers`: Number of CPU workers in the PyTorch dataloader.\n", " - `batch_size`: Batch size.\n", "\n", "The data is first loaded into a Pandas dataframe and split into training, validation, and test parts. Then the Pandas dataframes are converted to the appropriate PyTorch dataset required for training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The ECL data is available from https://github.com/zhouhaoyi/Informer2020?tab=readme-ov-file#data\n", "dataset_path = \"~/data/ECL.csv\"\n", "timestamp_column = \"date\"\n", "id_columns = []\n", "\n", "context_length = 512\n", "forecast_horizon = 96\n", "patch_length = 16\n", "num_workers = 16 # Reduce this if you have low number of CPU cores\n", "batch_size = 64 # Adjust according to GPU memory" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv(\n", " dataset_path,\n", " parse_dates=[timestamp_column],\n", ")\n", "forecast_columns = list(data.columns[1:])\n", "\n", "# get split\n", "num_train = int(len(data) * 0.7)\n", "num_test = int(len(data) * 0.2)\n", "num_valid = len(data) - num_train - num_test\n", "border1s = [\n", " 0,\n", " num_train - context_length,\n", " len(data) - num_test - context_length,\n", "]\n", "border2s = [num_train, num_train + num_valid, len(data)]\n", "\n", "train_start_index = border1s[0] # None indicates beginning of dataset\n", "train_end_index = border2s[0]\n", "\n", "# we shift the start of the evaluation period back by context length so that\n", "# the first evaluation timestamp is immediately following the training data\n", "valid_start_index = border1s[1]\n", "valid_end_index = border2s[1]\n", "\n", "test_start_index = border1s[2]\n", "test_end_index = border2s[2]\n", "\n", "train_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=train_start_index,\n", " end_index=train_end_index,\n", ")\n", "valid_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=valid_start_index,\n", " end_index=valid_end_index,\n", ")\n", "test_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=test_start_index,\n", " end_index=test_end_index,\n", ")\n", "\n", "time_series_preprocessor = TimeSeriesPreprocessor(\n", " timestamp_column=timestamp_column,\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " scaling=True,\n", ")\n", "time_series_preprocessor = time_series_preprocessor.train(train_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dataset = ForecastDFDataset(\n", " time_series_preprocessor.preprocess(train_data),\n", " id_columns=id_columns,\n", " timestamp_column=\"date\",\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")\n", "valid_dataset = ForecastDFDataset(\n", " time_series_preprocessor.preprocess(valid_data),\n", " id_columns=id_columns,\n", " timestamp_column=\"date\",\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")\n", "test_dataset = ForecastDFDataset(\n", " time_series_preprocessor.preprocess(test_data),\n", " id_columns=id_columns,\n", " timestamp_column=\"date\",\n", " 
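# input_columns and output_columns are identical here: we forecast the same channels that we condition on\n", " 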
input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configure the PatchTST model\n", "\n", "Next, we instantiate a randomly initialized `PatchTST` model with a configuration. The settings below control the different hyperparameters related to the architecture.\n", " - `num_input_channels`: the number of input channels (or dimensions) in the time series data. This is\n", " automatically set to the number of forecast columns.\n", " - `context_length`: As described above, the amount of historical data used as input to the model.\n", " - `patch_length`: The length of the patches extracted from the context window (of length `context_length`).\n", " - `patch_stride`: The stride used when extracting patches from the context window.\n", " - `prediction_length`: The number of future time steps to forecast, set here to the `forecast_horizon` defined above.\n", " - `random_mask_ratio`: The fraction of input patches that are completely masked for pretraining the model.\n", " - `d_model`: Dimension of the transformer layers.\n", " - `num_attention_heads`: The number of attention heads for each attention layer in the Transformer encoder.\n", " - `num_hidden_layers`: The number of encoder layers.\n", " - `ffn_dim`: Dimension of the intermediate (often referred to as feed-forward) layer in the encoder.\n", " - `dropout`: Dropout probability for all fully connected layers in the encoder.\n", " - `head_dropout`: Dropout probability used in the head of the model.\n", " - `pooling_type`: Pooling of the embedding. `\"mean\"`, `\"max\"`, and `None` are supported.\n", " - `channel_attention`: Activate the channel attention block in the Transformer to allow channels to attend to each other.\n", " - `scaling`: Whether to scale the input targets via \"mean\" scaler, \"std\" scaler, or no scaler if `None`. If `True`, the\n", " scaler is set to `\"mean\"`.\n", " - `loss`: The loss function for the model corresponding to the `distribution_output` head. For parametric\n", " distributions it is the negative log-likelihood (`\"nll\"`) and for point estimates it is the mean squared\n", " error `\"mse\"`.\n", " - `pre_norm`: Normalization is applied before self-attention if pre_norm is set to `True`. Otherwise, normalization is\n", " applied after the residual block.\n", " - `norm_type`: Normalization at each Transformer layer. Can be `\"batchnorm\"` or `\"layernorm\"`.\n", "\n", "For full details on the parameters, we refer to the [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/patchtst#transformers.PatchTSTConfig).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = PatchTSTConfig(\n", " num_input_channels=len(forecast_columns),\n", " context_length=context_length,\n", " patch_length=patch_length,\n", " patch_stride=patch_length,\n", " prediction_length=forecast_horizon,\n", " random_mask_ratio=0.4,\n", " d_model=128,\n", " num_attention_heads=16,\n", " num_hidden_layers=3,\n", " ffn_dim=256,\n", " dropout=0.2,\n", " head_dropout=0.2,\n", " pooling_type=None,\n", " channel_attention=False,\n", " scaling=\"std\",\n", " loss=\"mse\",\n", " pre_norm=True,\n", " norm_type=\"batchnorm\",\n", ")\n", "model = PatchTSTForPrediction(config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train model\n", "\n", "Next, we can leverage the Hugging Face [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) class to train the model based on the direct forecasting strategy. 
We first define the [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) which lists various hyperparameters for training such as the number of epochs, learning rate and so on." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_args = TrainingArguments(\n", " output_dir=\"./checkpoint/patchtst/electricity/pretrain/output/\",\n", " overwrite_output_dir=True,\n", " # learning_rate=0.001,\n", " num_train_epochs=100,\n", " do_eval=True,\n", " eval_strategy=\"epoch\",\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " dataloader_num_workers=num_workers,\n", " save_strategy=\"epoch\",\n", " logging_strategy=\"epoch\",\n", " save_total_limit=3,\n", " logging_dir=\"./checkpoint/patchtst/electricity/pretrain/logs/\", # Make sure to specify a logging directory\n", " load_best_model_at_end=True, # Load the best model when training ends\n", " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n", " greater_is_better=False, # For loss\n", " label_names=[\"future_values\"],\n", ")\n", "\n", "# Create the early stopping callback\n", "early_stopping_callback = EarlyStoppingCallback(\n", " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n", " early_stopping_threshold=0.0001, # Minimum improvement required to consider as improvement\n", ")\n", "\n", "# define trainer\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=valid_dataset,\n", " callbacks=[early_stopping_callback],\n", " # compute_metrics=compute_metrics,\n", ")\n", "\n", "# pretrain\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate the model on the test set of the source domain\n", "\n", "Next, we can leverage `trainer.evaluate()` to calculate test metrics. While this is not the target metric to judge in this task, it provides a reasonable check that the pretrained model has trained properly.\n", "Note that the training and evaluation loss for `PatchTST` is the Mean Squared Error (MSE) loss. Hence, we do not separately compute the MSE metric in any of the following evaluation experiments." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "results = trainer.evaluate(test_dataset)\n", "print(\"Test result:\")\n", "print(results)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The MSE of `0.131` is very close to the value reported for the Electricity dataset in the original `PatchTST` paper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Save model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_dir = \"patchtst/electricity/model/pretrain/\"\n", "os.makedirs(save_dir, exist_ok=True)\n", "trainer.save_model(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2: Transfer Learning from Electricity to ETTh1\n", "\n", "\n", "In this section, we will demonstrate the transfer learning capability of the `PatchTST` model.\n", "We use the model pre-trained on the Electricity dataset to do zero-shot forecasting on the ETTh1 dataset.\n", "\n", "\n", "By Transfer Learning, we mean that we first pretrain the model for a forecasting task on a `source` dataset (which we did above on the `Electricity` dataset). 
Then, we will use the pretrained model for zero-shot forecasting on a `target` dataset. By zero-shot, we mean that we test the performance in the `target` domain without any additional training. We hope that the model gained enough knowledge from pretraining which can be transferred to a different dataset. \n", "Subsequently, we will do linear probing and (then) finetuning of the pretrained model on the `train` split of the target data and will validate the forecasting performance on the `test` split of the target data. In this example, the source dataset is the `Electricity` dataset and the target dataset is ETTh1.\n", "\n", "### Transfer learning on ETTh1 data. \n", "All evaluations are on the `test` part of the `ETTh1` data.\n", "\n", "Step 1: Directly evaluate the electricity-pretrained model. This is the zero-shot performance. \n", "\n", "Step 2: Evaluate after doing linear probing. \n", "\n", "Step 3: Evaluate after doing full finetuning. \n", "\n", "### Load ETTh dataset\n", "\n", "Below, we load the `ETTh1` dataset as a Pandas dataframe. Next, we create 3 splits for training, validation, and testing. We then leverage the `TimeSeriesPreprocessor` class to prepare each split for the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset = \"ETTh1\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"Loading target dataset: {dataset}\")\n", "dataset_path = f\"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv\"\n", "timestamp_column = \"date\"\n", "id_columns = []\n", "forecast_columns = [\"HUFL\", \"HULL\", \"MUFL\", \"MULL\", \"LUFL\", \"LULL\", \"OT\"]\n", "train_start_index = None # None indicates beginning of dataset\n", "train_end_index = 12 * 30 * 24\n", "\n", "# we shift the start of the evaluation period back by context length so that\n", "# the first evaluation timestamp is immediately following the training data\n", "valid_start_index = 12 * 30 * 24 - context_length\n", "valid_end_index = 12 * 30 * 24 + 4 * 30 * 24\n", "\n", "test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length\n", "test_end_index = 12 * 30 * 24 + 8 * 30 * 24" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = pd.read_csv(\n", " dataset_path,\n", " parse_dates=[timestamp_column],\n", ")\n", "\n", "train_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=train_start_index,\n", " end_index=train_end_index,\n", ")\n", "valid_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=valid_start_index,\n", " end_index=valid_end_index,\n", ")\n", "test_data = select_by_index(\n", " data,\n", " id_columns=id_columns,\n", " start_index=test_start_index,\n", " end_index=test_end_index,\n", ")\n", "\n", "time_series_preprocessor = TimeSeriesPreprocessor(\n", " timestamp_column=timestamp_column,\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " scaling=True,\n", ")\n", "time_series_preprocessor = time_series_preprocessor.train(train_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dataset = ForecastDFDataset(\n", " time_series_preprocessor.preprocess(train_data),\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " 
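# context_length and forecast_horizon are kept the same as in pretraining so the pretrained weights can be reused\n", " 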
prediction_length=forecast_horizon,\n", ")\n", "valid_dataset = ForecastDFDataset(\n", " time_series_preprocessor.preprocess(valid_data),\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")\n", "test_dataset = ForecastDFDataset(\n", " time_series_preprocessor.preprocess(test_data),\n", " id_columns=id_columns,\n", " input_columns=forecast_columns,\n", " output_columns=forecast_columns,\n", " context_length=context_length,\n", " prediction_length=forecast_horizon,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Zero-shot forecasting on ETTH\n", "\n", "As we are going to test forecasting performance out-of-the-box, we load the model which we pretrained above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "finetune_forecast_model = PatchTSTForPrediction.from_pretrained(\n", " \"patchtst/electricity/model/pretrain/\",\n", " num_input_channels=len(forecast_columns),\n", " head_dropout=0.7,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "finetune_forecast_args = TrainingArguments(\n", " output_dir=\"./checkpoint/patchtst/transfer/finetune/output/\",\n", " overwrite_output_dir=True,\n", " learning_rate=0.0001,\n", " num_train_epochs=100,\n", " do_eval=True,\n", " eval_strategy=\"epoch\",\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " dataloader_num_workers=num_workers,\n", " report_to=\"tensorboard\",\n", " save_strategy=\"epoch\",\n", " logging_strategy=\"epoch\",\n", " save_total_limit=3,\n", " logging_dir=\"./checkpoint/patchtst/transfer/finetune/logs/\", # Make sure to specify a logging directory\n", " load_best_model_at_end=True, # Load the best model when training ends\n", " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n", " greater_is_better=False, # For loss\n", " label_names=[\"future_values\"],\n", ")\n", "\n", "# Create a new early stopping callback with faster convergence properties\n", "early_stopping_callback = EarlyStoppingCallback(\n", " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n", " early_stopping_threshold=0.001, # Minimum improvement required to consider as improvement\n", ")\n", "\n", "finetune_forecast_trainer = Trainer(\n", " model=finetune_forecast_model,\n", " args=finetune_forecast_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=valid_dataset,\n", " callbacks=[early_stopping_callback],\n", ")\n", "\n", "print(\"\\n\\nDoing zero-shot forecasting on target data\")\n", "result = finetune_forecast_trainer.evaluate(test_dataset)\n", "print(\"Target data zero-shot forecasting result:\")\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As can be seen, with a zero-shot forecasting approach we obtain an MSE of 0.370 which is near to the state-of-the-art result in the original `PatchTST` paper.\n", "\n", "Next, let's see how we can do by performing linear probing, which involves training a linear layer on top of a frozen pre-trained model. Linear probing is often done to test the performance of features of a pretrained model." 
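, "\n", "As a quick sanity check (a minimal sketch, not part of the original recipe), you can count the model's trainable parameters before and after the backbone is frozen in the linear probing step below; after freezing, only the prediction head should remain trainable:\n", "\n", "```python\n", "# Sketch: compare trainable vs. total parameter counts of the loaded model\n", "trainable = sum(p.numel() for p in finetune_forecast_model.parameters() if p.requires_grad)\n", "total = sum(p.numel() for p in finetune_forecast_model.parameters())\n", "print(f\"trainable parameters: {trainable:,} / {total:,}\")\n", "```"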
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Linear probing on ETTh1\n", "\n", "We can do a quick linear probing on the `train` part of the target data to see any possible `test` performance improvement." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Freeze the backbone of the model\n", "for param in finetune_forecast_trainer.model.model.parameters():\n", " param.requires_grad = False\n", "\n", "print(\"\\n\\nLinear probing on the target data\")\n", "finetune_forecast_trainer.train()\n", "print(\"Evaluating\")\n", "result = finetune_forecast_trainer.evaluate(test_dataset)\n", "print(\"Target data head/linear probing result:\")\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As can be seen, by only training a simple linear layer on top of the frozen backbone, the MSE decreased from 0.370 to 0.357, beating the originally reported results!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/model/linear_probe/\"\n", "os.makedirs(save_dir, exist_ok=True)\n", "finetune_forecast_trainer.save_model(save_dir)\n", "\n", "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/preprocessor/\"\n", "os.makedirs(save_dir, exist_ok=True)\n", "time_series_preprocessor.save_pretrained(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's see if we can get additional improvements by doing a full fine-tune of the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Full fine-tune on ETTh1\n", "\n", "We can do a full model fine-tune (instead of probing the last linear layer as shown above) on the `train` part of the target data to see a possible `test` performance improvement. The code looks similar to the linear probing task above, except that we are not freezing any parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reload the model\n", "finetune_forecast_model = PatchTSTForPrediction.from_pretrained(\n", " \"patchtst/electricity/model/pretrain/\",\n", " num_input_channels=len(forecast_columns),\n", " dropout=0.7,\n", " head_dropout=0.7,\n", ")\n", "finetune_forecast_trainer = Trainer(\n", " model=finetune_forecast_model,\n", " args=finetune_forecast_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=valid_dataset,\n", " callbacks=[early_stopping_callback],\n", ")\n", "print(\"\\n\\nFinetuning on the target data\")\n", "finetune_forecast_trainer.train()\n", "print(\"Evaluating\")\n", "result = finetune_forecast_trainer.evaluate(test_dataset)\n", "print(\"Target data full finetune result:\")\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, there is only a small improvement on the ETTh1 dataset with full fine-tuning. For other datasets there may be more substantial improvements. Let's save the model anyway." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/model/fine_tuning/\"\n", "os.makedirs(save_dir, exist_ok=True)\n", "finetune_forecast_trainer.save_model(save_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "In this blog, we presented a step-by-step guide on training `PatchTST` for tasks related to forecasting and transfer learning, demonstrating various approaches for fine-tuning. 
We intend to facilitate easy integration of the `PatchTST` HF model for your forecasting use cases, and we hope that this content serves as a useful resource to expedite the adoption of PatchTST. Thank you for tuning in to our blog, and we hope you find this information beneficial for your projects." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 4 }