{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Time Series Forecasting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we will demonstrate how to build a model for time series forecasting in NumPyro. Specifically, we will replicate the **Seasonal, Global Trend (SGT)** model from the [Rlgt: Bayesian Exponential Smoothing Models with Trend Modifications](https://cran.r-project.org/web/packages/Rlgt/index.html) package. The time series data that we will use for this tutorial is the **lynx** dataset, which contains the annual number of lynx trappings in Canada from 1821 to 1934." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "from IPython.display import set_matplotlib_formats\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "\n", "import jax.numpy as np\n", "from jax import lax, random, vmap\n", "from jax.nn import softmax\n", "\n", "import numpyro; numpyro.set_host_device_count(4)\n", "import numpyro.distributions as dist\n", "from numpyro.diagnostics import autocorrelation, hpdi\n", "from numpyro import handlers\n", "from numpyro.infer import MCMC, NUTS, Predictive\n", "\n", "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", "    set_matplotlib_formats('svg')\n", "\n", "assert numpyro.__version__.startswith('0.2.4')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's load the dataset and take a look at it." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Length of time series: 114\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "URL = \"https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/datasets/lynx.csv\"\n", "lynx = pd.read_csv(URL, index_col=0)\n", "data = lynx[\"value\"].values\n", "print(\"Length of time series:\", data.shape[0])\n", "plt.figure(figsize=(8, 4))\n", "plt.plot(lynx[\"time\"], data)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The time series has a length of 114 (one data point per year). Looking at the plot, we can observe [seasonality](https://en.wikipedia.org/wiki/Seasonality) in this dataset, that is, the recurrence of similar patterns at regular time intervals. For example, there is a cyclical pattern roughly every 10 years, as well as a less obvious spike in the number of trappings roughly every 40 years. Let us see if we can model these effects in NumPyro.\n", "\n", "In this tutorial, we will use the first 80 values for training and the last 34 values for testing."
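] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The roughly 10-year cycle described above can also be located numerically, by finding the lag at which the sample autocorrelation peaks. The sketch below is a minimal illustration under stated assumptions, not part of the original tutorial: it uses plain NumPy (imported as `onp` so that it does not shadow the `np` alias for `jax.numpy` used in this notebook), a synthetic series with a 10-step cycle stands in for the lynx counts, and the helper `sample_autocorr` is hypothetical.

```python
import numpy as onp


def sample_autocorr(x):
    # Normalized (biased) sample autocorrelation for lags 0 .. len(x) - 1.
    x = onp.asarray(x, dtype=float) - onp.mean(x)
    n = len(x)
    acov = onp.array([onp.dot(x[:n - k], x[k:]) for k in range(n)])
    return acov / acov[0]


# Hypothetical stand-in for the training data: a 10-step cycle plus noise.
rng = onp.random.default_rng(0)
t = onp.arange(80)
series = 1000 + 800 * onp.sin(2 * onp.pi * t / 10) + rng.normal(0, 50, size=t.size)

acf = sample_autocorr(series)
# Skip lag 0 (always 1) and search the short lags for the dominant cycle.
best_lag = 1 + int(onp.argmax(acf[1:20]))
print(best_lag)  # 10
```

The same idea is applied to the real training data in the Inference section, using `numpyro.diagnostics.autocorrelation`."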
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "y_train, y_test = np.array(data[:80], dtype=np.float32), data[80:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model we are going to use is called **Seasonal, Global Trend**. When tested on the 3003 time series of the [M-3 competition](https://forecasters.org/resources/time-series-data/m3-competition/), it has been reported to outperform the models that originally participated in the competition:\n", "\n", "$$\n", "\\begin{align}\n", "\\text{exp_val}_{t} &= \\text{level}_{t-1} + \\text{coef_trend} \\times \\text{level}_{t-1}^{\\text{pow_trend}} + \\text{s}_t \\times \\text{level}_{t-1}^{\\text{pow_season}}, \\\\\n", "\\sigma_{t} &= \\sigma \\times \\text{exp_val}_{t}^{\\text{powx}} + \\text{offset}, \\\\\n", "y_{t} &\\sim \\text{StudentT}(\\nu, \\text{exp_val}_{t}, \\sigma_{t})\n", "\\end{align}\n", "$$\n", "\n", "where `level` and `s` follow these recursion rules:\n", "\n", "$$\n", "\\begin{align}\n", "\\text{level_p} &=\n", "\\begin{cases}\n", "y_t - \\text{s}_t \\times \\text{level}_{t-1}^{\\text{pow_season}} & \\text{if } t \\le \\text{seasonality}, \\\\ \n", "\\text{Average} \\left[y(t - \\text{seasonality} + 1), \\ldots, y(t)\\right] & \\text{otherwise},\n", "\\end{cases} \\\\\n", "\\text{level}_{t} &= \\text{level_sm} \\times \\text{level_p} + (1 - \\text{level_sm}) \\times \\text{level}_{t-1}, \\\\\n", "\\text{s}_{t + \\text{seasonality}} &= \\text{s_sm} \\times \\frac{y_{t} - \\text{level}_{t}}{\\text{level}_{t-1}^{\\text{pow_season}}}\n", "+ (1 - \\text{s_sm}) \\times \\text{s}_{t}.\n", "\\end{align}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A more detailed explanation of the SGT model can be found in [this vignette](https://cran.r-project.org/web/packages/Rlgt/vignettes/GT_models.html) from the authors of the Rlgt package. 
Here we summarize the core ideas of this model:\n", "\n", "+ The [Student's t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution), which has heavier tails than the normal distribution, is used for the likelihood.\n", "+ The expected value `exp_val` consists of a trend component and a seasonal component:\n", "    - The trend is governed by the map $x \\mapsto x + ax^b$, where $x$ is `level`, $a$ is `coef_trend`, and $b$ is `pow_trend`. Note that when $b \\approx 0$, the trend is linear with slope $a$, and when $b \\approx 1$, the trend is exponential with rate $a$, so this map covers a large family of trends.\n", "    - At each time step, `level` and `s` are updated to new values. The coefficients `level_sm` and `s_sm` smooth these transitions.\n", "+ When `powx` is near $0$, the error scale $\\sigma_t$ is nearly constant, while when `powx` is near $1$, it grows roughly in proportion to the expected value.\n", "+ There are several varieties of SGT. In this tutorial, we use generalized seasonality and the seasonal average method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `level` and `s` are updated recursively while we collect the expected value at each time step. NumPyro uses [JAX](https://github.com/google/jax) in the backend to JIT compile many critical parts of the NUTS algorithm, including the Verlet integrator and the tree building process. However, a Python `for` loop in the model would be unrolled during JIT compilation, resulting in a long compilation time, so we use `jax.lax.scan` instead. A detailed explanation of this utility can be found in the [lax.scan documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan). Here we use it to collect the expected values, while the tuple `(level, s, moving_sum)` is the carried state."
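] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To illustrate how `lax.scan` threads a carried state through a recursion, here is a minimal sketch of simple exponential smoothing, a much simpler relative of the `level` update in SGT. It is not taken from Rlgt or from the model in this tutorial; the step function `smooth_step`, the weight `alpha` and the toy data are hypothetical. The step function receives `(carry, x)` and returns `(new_carry, output)`, and `lax.scan` returns the final carry together with the stacked per-step outputs, which should match an explicit Python loop.

```python
import jax.numpy as jnp
import numpy as onp
from jax import lax

alpha = 0.3  # hypothetical smoothing weight, analogous to level_sm


def smooth_step(level, y_t):
    # carry: the current level; output: the one-step-ahead prediction
    prediction = level
    new_level = alpha * y_t + (1 - alpha) * level
    return new_level, prediction


y = jnp.array([10.0, 12.0, 9.0, 14.0, 11.0])

# lax.scan threads the carry through every step and stacks the outputs.
last_level, predictions = lax.scan(smooth_step, y[0], y[1:])

# The same recursion written as an explicit Python loop, for comparison.
level = 10.0
loop_predictions = []
for y_t in [12.0, 9.0, 14.0, 11.0]:
    loop_predictions.append(level)
    level = alpha * y_t + (1 - alpha) * level

print(predictions, last_level)
print(onp.allclose(onp.asarray(predictions), onp.array(loop_predictions), atol=1e-4))
```

In `scan_exp_val`, the same pattern is used with a richer state: the carry is the tuple `(level, s, moving_sum)`, the scanned inputs are the time indices, and the per-step output is `exp_val`."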
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def scan_exp_val(\n", "    y, init_s, level_sm, s_sm, coef_trend, pow_trend, pow_season, future=0):\n", "    N = y.shape[0]\n", "    duration = N + future\n", "    seasonality = init_s.shape[0]\n", "\n", "    def scan_fn(carry, t):\n", "        level, s, moving_sum = carry\n", "        season = s[0] * level ** pow_season\n", "        exp_val = level + coef_trend * level ** pow_trend + season\n", "        exp_val = np.clip(exp_val, a_min=0)\n", "        # use the expected value in place of y[t] when forecasting\n", "        y_t = np.where(t >= N, exp_val, y[t])\n", "\n", "        moving_sum = moving_sum + y[t] - \\\n", "            np.where(t >= seasonality, y[t - seasonality], 0.)\n", "        level_p = np.where(t >= seasonality, moving_sum / seasonality, y_t - season)\n", "        level = level_sm * level_p + (1 - level_sm) * level\n", "        level = np.clip(level, a_min=0)\n", "\n", "        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]\n", "        # repeat s when forecasting\n", "        new_s = np.where(t >= N, s[0], new_s)\n", "        s = np.concatenate([s[1:], new_s[None]], axis=0)\n", "        return (level, s, moving_sum), exp_val\n", "\n", "    level_init = y[0]\n", "    s_init = np.concatenate([init_s[1:], init_s[:1]], axis=0)\n", "    moving_sum = level_init\n", "    (last_level, last_s, moving_sum), exp_vals = lax.scan(\n", "        scan_fn, (level_init, s_init, moving_sum), np.arange(1, duration))\n", "    return exp_vals, last_level, last_s" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the utility function defined above, we are ready to specify the model using *NumPyro* primitives. In NumPyro, we use the primitive `sample(name, prior)` to declare a latent random variable with a corresponding `prior`. These primitives can have custom interpretations depending on the effect handlers that are used by NumPyro inference algorithms in the backend. For example, 
we can condition on specific values using the `substitute` handler, or record values at these sample sites in the execution trace using the `trace` handler. Note that these details are not needed for specifying the model or running inference, but curious readers are encouraged to read the [tutorial on effect handlers](http://pyro.ai/examples/effect_handlers.html) in Pyro." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def sgt(y, seasonality, future=0):\n", "    # heuristically, the scale of the Cauchy prior depends on\n", "    # the max value of the data\n", "    cauchy_sd = np.max(y) / 150\n", "\n", "    nu = numpyro.sample(\"nu\", dist.Uniform(2, 20))\n", "    powx = numpyro.sample(\"powx\", dist.Uniform(0, 1))\n", "    sigma = numpyro.sample(\"sigma\", dist.HalfCauchy(cauchy_sd))\n", "    offset_sigma = numpyro.sample(\"offset_sigma\", dist.TruncatedCauchy(\n", "        low=1e-10, loc=1e-10, scale=cauchy_sd))\n", "\n", "    coef_trend = numpyro.sample(\"coef_trend\", dist.Cauchy(0, cauchy_sd))\n", "    pow_trend_beta = numpyro.sample(\"pow_trend_beta\", dist.Beta(1, 1))\n", "    # pow_trend takes values from -0.5 to 1\n", "    pow_trend = 1.5 * pow_trend_beta - 0.5\n", "    pow_season = numpyro.sample(\"pow_season\", dist.Beta(1, 1))\n", "\n", "    level_sm = numpyro.sample(\"level_sm\", dist.Beta(1, 2))\n", "    s_sm = numpyro.sample(\"s_sm\", dist.Uniform(0, 1))\n", "    init_s = numpyro.sample(\"init_s\", dist.Cauchy(0, y[:seasonality] * 0.3))\n", "\n", "    exp_val, last_level, last_s = scan_exp_val(\n", "        y, init_s, level_sm, s_sm, coef_trend, pow_trend, pow_season, future=future)\n", "    if future == 0:  # training\n", "        omega = sigma * exp_val ** powx + offset_sigma\n", "        numpyro.sample(\"y\", dist.StudentT(nu, exp_val, omega), obs=y[1:])\n", "        # we return last `level` and last `s` for custom forecasting\n", "        return last_level, last_s\n", "    else:  # forecasting\n", "        exp_val = exp_val[y.shape[0] - 1:]\n", "        assert exp_val.shape[0] == future\n", "        omega = sigma * exp_val ** powx + offset_sigma\n", "        numpyro.sample(\"y\", dist.StudentT(nu, exp_val, omega))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that all prior parameters are taken from [this file](https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/rlgtcontrol.R) in the original Rlgt source code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we want to choose a good value for `seasonality`. Following [the demo in Rlgt](https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/demo/lynx.R), we will set `seasonality=38`. This value can also be guessed from the plot of the training data, where the second-order seasonality effect has a periodicity of around $40$ years. Note that $38$ is also one of the lags with the highest autocorrelation." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Lag values sorted according to their autocorrelation values:\n", "\n", "[ 0 67 57 38 68 1 29 58 37 56 28 10 19 39 66 78 47 77 9 79 48 76 30 18\n", " 20 11 46 59 69 27 55 36 2 8 40 49 17 21 75 12 65 45 31 26 7 54 35 41\n", " 50 3 22 60 70 16 44 13 6 25 74 53 42 32 23 43 51 4 15 14 34 24 5 52\n", " 73 64 33 71 72 61 63 62]\n" ] } ], "source": [ "print(\"Lag values sorted according to their autocorrelation values:\\n\")\n", "print(np.argsort(autocorrelation(y_train))[::-1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let us run $4$ MCMC chains using the No-U-Turn Sampler (NUTS), with $5000$ warmup steps and $5000$ sampling steps per chain. The result is a collection of $4 \\times 5000 = 20000$ posterior samples."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " coef_trend 35.90 141.51 13.49 -89.69 160.01 1253.52 1.00\n", " init_s[0] 93.88 115.72 68.03 -48.20 267.54 563.41 1.01\n", " init_s[1] -20.74 68.27 -22.57 -132.33 86.30 5638.05 1.00\n", " init_s[2] 31.34 91.71 21.30 -109.71 172.00 5751.88 1.00\n", " init_s[3] 126.35 123.06 109.89 -67.25 304.43 5308.82 1.00\n", " init_s[4] 449.81 249.57 407.17 72.60 799.35 3962.26 1.00\n", " init_s[5] 1192.26 459.23 1124.14 481.37 1836.72 2008.70 1.00\n", " init_s[6] 2013.44 656.04 1932.91 967.32 2981.27 2277.62 1.00\n", " init_s[7] 3725.03 1110.80 3613.86 1992.54 5409.12 1860.04 1.00\n", " init_s[8] 2606.46 840.53 2476.63 1328.71 3882.42 1682.46 1.01\n", " init_s[9] 956.78 426.92 899.53 306.73 1601.91 3984.62 1.00\n", " init_s[10] 50.27 105.87 36.33 -106.86 205.88 5185.16 1.00\n", " init_s[11] -0.19 56.79 -2.60 -77.66 75.06 1584.26 1.00\n", " init_s[12] -7.15 67.90 -10.02 -106.90 98.28 708.60 1.00\n", " init_s[13] 66.97 100.33 47.12 -74.98 216.03 5203.29 1.00\n", " init_s[14] 335.98 259.67 276.09 -12.78 709.36 3362.87 1.00\n", " init_s[15] 967.16 400.29 904.31 385.98 1558.11 849.66 1.01\n", " init_s[16] 1271.75 465.74 1209.79 557.99 1961.61 2849.68 1.00\n", " init_s[17] 1386.30 547.93 1284.83 571.16 2217.53 2725.54 1.00\n", " init_s[18] 613.97 307.38 550.73 175.59 1070.71 3319.73 1.00\n", " init_s[19] 16.35 92.00 3.67 -119.55 147.91 1989.23 1.00\n", " init_s[20] -30.11 66.05 -23.57 -144.37 63.33 4514.26 1.00\n", " init_s[21] -16.96 47.14 -6.20 -93.59 42.93 1696.88 1.00\n", " init_s[22] -0.60 43.07 -0.91 -65.73 61.91 3235.73 1.00\n", " init_s[23] 39.57 83.49 24.38 -82.75 165.80 5226.04 1.00\n", " init_s[24] 529.39 330.23 475.21 10.02 986.28 4837.91 1.00\n", " init_s[25] 942.35 457.37 864.48 277.80 1594.75 3319.35 1.00\n", " init_s[26] 1855.40 912.22 1731.51 765.43 2798.29 632.94 1.01\n", " init_s[27] 
1308.19 492.55 1231.14 542.53 1982.89 715.91 1.00\n", " init_s[28] 217.56 172.60 181.49 -27.86 469.58 4992.26 1.00\n", " init_s[29] -11.99 82.70 -19.44 -142.68 108.59 2410.68 1.00\n", " init_s[30] -1.89 87.65 -8.74 -136.86 121.59 5177.89 1.00\n", " init_s[31] -39.47 75.38 -37.94 -170.10 70.53 283.89 1.01\n", " init_s[32] -9.67 86.13 -19.10 -141.70 119.14 5254.49 1.00\n", " init_s[33] 120.82 134.18 100.77 -82.96 317.41 4123.89 1.00\n", " init_s[34] 504.63 302.64 451.40 90.37 929.83 1448.21 1.00\n", " init_s[35] 1101.46 463.20 1026.01 443.47 1797.68 725.92 1.01\n", " init_s[36] 1891.59 666.78 1793.22 884.27 2871.91 1886.95 1.00\n", " init_s[37] 1433.02 573.39 1352.83 525.74 2222.67 463.54 1.01\n", " level_sm 0.00 0.00 0.00 0.00 0.00 3763.91 1.00\n", " nu 12.25 4.73 12.72 5.40 20.00 6521.37 1.00\n", " offset_sigma 33.50 30.53 25.27 0.00 72.23 5231.03 1.00\n", " pow_season 0.08 0.04 0.08 0.01 0.14 1244.31 1.01\n", " pow_trend_beta 0.26 0.18 0.24 0.00 0.51 135.08 1.02\n", " powx 0.63 0.14 0.62 0.42 0.86 358.39 1.01\n", " s_sm 0.08 0.09 0.06 0.00 0.20 637.77 1.01\n", " sigma 9.24 9.60 6.29 0.33 20.04 2796.07 1.00\n", "\n", "Number of divergences: 4402\n", "CPU times: user 1min 15s, sys: 219 ms, total: 1min 16s\n", "Wall time: 48.6 s\n" ] } ], "source": [ "%%time\n", "kernel = NUTS(sgt)\n", "mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)\n", "mcmc.run(random.PRNGKey(2), y_train, seasonality=38)\n", "mcmc.print_summary()\n", "samples = mcmc.get_samples()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Forecasting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given `samples` from `mcmc`, we want to do forecasting for the testing dataset `y_test`. First, we will make some utilities to do forecasting given a sample. Note that to retrieve the last `level` and last `s` value, we run the model forward by constraining the latent sites to a sample from the posterior using the `substitute` handler:\n", "\n", "```python\n", "... 
level, s = substitute(sgt, sample)(y, seasonality)\n", "```" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Ref: https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/forecast.rlgtfit.R\n", "def sgt_forecast(future, sample, y, level, s):\n", "    seasonality = s.shape[0]\n", "    moving_sum = np.sum(y[-seasonality:])\n", "    pow_trend = 1.5 * sample[\"pow_trend_beta\"] - 0.5\n", "    yfs = [0] * (seasonality + future)\n", "    for t in range(future):\n", "        season = s[0] * level ** sample[\"pow_season\"]\n", "        exp_val = level + sample[\"coef_trend\"] * level ** pow_trend + season\n", "        exp_val = np.clip(exp_val, a_min=0)\n", "        omega = sample[\"sigma\"] * exp_val ** sample[\"powx\"] + sample[\"offset_sigma\"]\n", "        yf = numpyro.sample(\"yf[{}]\".format(t), dist.StudentT(\n", "            sample[\"nu\"], exp_val, omega))\n", "        yf = np.clip(yf, a_min=1e-30)\n", "        yfs[t] = yf\n", "\n", "        moving_sum = moving_sum + yf - \\\n", "            np.where(t >= seasonality, yfs[t - seasonality], y[-seasonality + t])\n", "        level_p = moving_sum / seasonality\n", "        level_tmp = sample[\"level_sm\"] * level_p + (1 - sample[\"level_sm\"]) * level\n", "        level = np.where(level_tmp > 1e-30, level_tmp, level)\n", "        # s is repeated instead of being updated\n", "        s = np.concatenate([s[1:], s[:1]], axis=0)\n", "\n", "\n", "def forecast(future, rng_key, sample, y, seasonality):\n", "    level, s = handlers.substitute(sgt, sample)(y, seasonality)\n", "    forecast_model = handlers.seed(sgt_forecast, rng_key)\n", "    forecast_trace = handlers.trace(forecast_model).get_trace(\n", "        future, sample, y, level, s)\n", "    results = [np.clip(forecast_trace[\"yf[{}]\".format(t)][\"value\"], a_min=1e-30)\n", "               for t in range(future)]\n", "    return np.stack(results, axis=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we can use [jax.vmap](https://jax.readthedocs.io/en/latest/jax.html#jax.vmap) to get predictions for a collection of posterior samples. 
This vectorizes the forecast computation across the posterior samples, which can be dramatically faster than using a Python `for` loop to compute the forecast one sample at a time." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "rng_keys = random.split(random.PRNGKey(3), samples[\"nu\"].shape[0])\n", "forecast_marginal = vmap(lambda rng_key, sample: forecast(\n", "    len(y_test), rng_key, sample, y_train, seasonality=38))(rng_keys, samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's compute the symmetric mean absolute percentage error (sMAPE) and the root mean squared error (RMSE) of the prediction, and visualize the result with the mean prediction and the 90% highest posterior density interval (HPDI)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sMAPE: 62.87, rmse: 1248.86\n" ] } ], "source": [ "y_pred = np.mean(forecast_marginal, axis=0)\n", "sMAPE = np.mean(np.abs(y_pred - y_test) / (y_pred + y_test)) * 200\n", "msqrt = np.sqrt(np.mean((y_pred - y_test) ** 2))\n", "print(\"sMAPE: {:.2f}, rmse: {:.2f}\".format(sMAPE, msqrt))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8, 4))\n", "plt.plot(lynx[\"time\"], data)\n", "t_future = lynx[\"time\"][80:]\n", "hpd_low, hpd_high = hpdi(forecast_marginal)\n", "plt.plot(t_future, y_pred, lw=2)\n", "plt.fill_between(t_future, hpd_low, hpd_high, alpha=0.3)\n", "plt.title(\"Forecasting lynx dataset with SGT model (90% HPDI)\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can observe, the model has been able to capture both the first- and second-order seasonality effects, i.e. a cyclical pattern with a periodicity of around 10 years, as well as spikes that appear roughly once every 40 years. Moreover, we not only have point estimates for the forecast but can also use the uncertainty estimates from the model to bound our forecasts." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Forecasting using Predictive" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "NumPyro provides a convenient utility, [Predictive](http://num.pyro.ai/en/stable/utilities.html#numpyro.infer.util.Predictive), to obtain the predictive distribution. Let's see how to use it to generate forecasts.\n", "\n", "Notice that the `sgt` model defined above has a keyword argument `future` that controls its execution: when `future == 0`, the model conditions on the training data, and when `future > 0`, it samples `future` new values of `y`. The following code predicts the last 34 values of the original time series." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "predictive = Predictive(sgt, samples, return_sites=[\"y\"])\n", "forecast2 = predictive(random.PRNGKey(4), y_train, seasonality=38, future=34)[\"y\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the result to check that it agrees with the forecast obtained above." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8, 4))\n", "plt.plot(lynx[\"time\"], data)\n", "t_future = lynx[\"time\"][80:]\n", "hpd_low, hpd_high = hpdi(forecast2)\n", "plt.plot(t_future, np.mean(forecast2, axis=0), lw=2)\n", "plt.fill_between(t_future, hpd_low, hpd_high, alpha=0.3)\n", "plt.title(\"Forecasting lynx dataset with SGT model (90% HPDI)\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Acknowledgements\n", "\n", "We would like to thank Slawek Smyl for many helpful resources and suggestions. Fast inference would not have been possible without the support of JAX and the XLA teams, so we would like to thank them for providing such a great open-source platform for us to build on, and for their responsiveness in dealing with our feature requests and bug reports." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\n", "[1] `Rlgt: Bayesian Exponential Smoothing Models with Trend Modifications`,
    \n", "Slawek Smyl, Christoph Bergmeir, Erwin Wibowo, To Wang Ng, Trustees of Columbia University" ] } ], "metadata": { "celltoolbar": "Edit Metadata", "kernelspec": { "display_name": "Python (pydata)", "language": "python", "name": "pydata" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }