{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autoregressive modelling with DeepAR and DeepVAR\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "# Suppress every warning for the rest of the kernel session.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Remember the directory the kernel started in, so that re-running this cell\n",
    "# resolves the relative path from the same place instead of climbing three\n",
    "# more levels on every execution. On a fresh kernel the target directory is\n",
    "# the same as with a plain relative chdir.\n",
    "# NOTE(review): three levels up is presumably the repository root - confirm.\n",
    "if \"_NOTEBOOK_START_DIR\" not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, \"../../..\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightning.pytorch as pl\n",
    "from lightning.pytorch.callbacks import EarlyStopping\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from pytorch_forecasting import Baseline, DeepAR, TimeSeriesDataSet\n",
    "from pytorch_forecasting.data import NaNLabelEncoder\n",
    "from pytorch_forecasting.data.examples import generate_ar_data\n",
    "from pytorch_forecasting.metrics import MAE, SMAPE, MultivariateNormalDistributionLoss"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "We generate a synthetic dataset to demonstrate the network's capabilities. The data consists of a quadratic trend and a seasonality component.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", " | series | \n", "time_idx | \n", "value | \n", "static | \n", "date | \n", "
---|---|---|---|---|---|
0 | \n", "0 | \n", "0 | \n", "-0.000000 | \n", "2 | \n", "2020-01-01 | \n", "
1 | \n", "0 | \n", "1 | \n", "-0.046501 | \n", "2 | \n", "2020-01-02 | \n", "
2 | \n", "0 | \n", "2 | \n", "-0.097796 | \n", "2 | \n", "2020-01-03 | \n", "
3 | \n", "0 | \n", "3 | \n", "-0.144397 | \n", "2 | \n", "2020-01-04 | \n", "
4 | \n", "0 | \n", "4 | \n", "-0.177954 | \n", "2 | \n", "2020-01-05 | \n", "