{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Gaussian filtering and smoothing (continuous-discrete)\n", "\n", "This notebook shows how to perform Bayesian filtering and smoothing in a linear Gaussian state-space model with continuous-time dynamics and discrete-time observations,\n", "in order to obtain a posterior distribution over a latent state trajectory from noisy observations.\n", "For the theory behind these methods, we refer to [1] and [2].\n", "\n", "**References**:\n", "> [1] Särkkä, Simo, and Solin, Arno. Applied Stochastic Differential Equations. Cambridge University Press, 2019.\n", ">\n", "> [2] Särkkä, Simo. Bayesian Filtering and Smoothing. Cambridge University Press, 2013." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "import probnum as pn\n", "from probnum import filtsmooth, randvars, randprocs\n", "from probnum.problems import TimeSeriesRegressionProblem" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "rng = np.random.default_rng(seed=123)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Make inline plots vector graphics instead of raster graphics\n", "%matplotlib inline\n", "from matplotlib_inline.backend_inline import set_matplotlib_formats\n", "\n", "set_matplotlib_formats(\"pdf\", \"svg\")\n", "\n", "# Plotting\n", "import matplotlib.pyplot as plt\n", "import matplotlib.gridspec as gridspec\n", "\n", "plt.style.use(\"../../probnum.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **Linear Continuous-Discrete** State-Space Model: Ornstein-Uhlenbeck Process" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "We now turn to **continuous-time** dynamics. We assume that a continuous-time stochastic process governs the evolution of the latent state, and that we collect discrete, linear-Gaussian measurements of this state. Only the dynamics model is continuous; the measurement model remains discrete. In particular, we formulate the dynamics as a linear time-invariant stochastic differential equation (LTI SDE). We refer to [1] for more details.\n", "Consider matrices $\\boldsymbol{F} \\in \\mathbb{R}^{d \\times d}$, $\\boldsymbol{L} \\in \\mathbb{R}^{d \\times s}$, $\\boldsymbol{H} \\in \\mathbb{R}^{m \\times d}$, and a measurement noise covariance $\\boldsymbol{R} \\in \\mathbb{R}^{m \\times m}$, where $d$ is the state dimension, $s$ is the dimension of the driving noise, and $m$ is the dimension of the measurements.\n", "We define the following **continuous-discrete** state-space model:\n", "\n", "Let $\\boldsymbol{x}(t_0) \\sim \\mathcal{N}(\\boldsymbol{\\mu}_0, \\boldsymbol{\\Sigma}_0)$.\n", "\n", "$$\n", "\\begin{align}\n", " d\\boldsymbol{x} &= \\boldsymbol{F} \\, \\boldsymbol{x} \\, dt + \\boldsymbol{L} \\, d \\boldsymbol{\\omega} \\\\\n", " \\boldsymbol{y}_k &\\sim \\mathcal{N}(\\boldsymbol{H} \\, \\boldsymbol{x}(t_k), \\boldsymbol{R}), \\qquad k = 1, \\dots, K\n", "\\end{align}\n", "$$\n", "\n", "where $\\boldsymbol{\\omega}$ is an $s$-dimensional driving noise process, typically a Brownian motion (Wiener process).\n", "\n", "This generalizes to linear time-varying state-space models, in which $\\boldsymbol{F}$, $\\boldsymbol{L}$, and $\\boldsymbol{H}$ are functions of time, $\\boldsymbol{F}: \\mathbb{T} \\rightarrow \\mathbb{R}^{d \\times d}$, $\\boldsymbol{L}: \\mathbb{T} \\rightarrow \\mathbb{R}^{d \\times s}$, and $\\boldsymbol{H}: \\mathbb{T} \\rightarrow \\mathbb{R}^{m \\times d}$, where $\\mathbb{T}$ denotes the time domain. In the following example, however, we consider an LTI SDE, namely the scalar Ornstein-Uhlenbeck process $dx = -\\theta \\, x \\, dt + \\sigma \\, d\\omega$ with $\\theta > 0$, from which we observe discrete, linear-Gaussian measurements."
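, "\n", "\n", "Between two measurement times, the filter only needs the transition of the dynamics over a step $h$. Assuming that $\\boldsymbol{\\omega}$ is a standard Brownian motion (unit diffusion), this transition is Gaussian, $\\boldsymbol{x}(t + h) \\mid \\boldsymbol{x}(t) \\sim \\mathcal{N}(\\boldsymbol{A}_h \\, \\boldsymbol{x}(t), \\boldsymbol{Q}_h)$, with $\\boldsymbol{A}_h = \\exp(\\boldsymbol{F} h)$ and $\\boldsymbol{Q}_h = \\int_0^h \\exp(\\boldsymbol{F} \\tau) \\, \\boldsymbol{L} \\boldsymbol{L}^\\top \\exp(\\boldsymbol{F} \\tau)^\\top \\, d\\tau$ (see [1]).\n", "The cell below is a minimal sketch of this discretization, not the `probnum` implementation. `discretize_lti_sde` is a hypothetical helper that evaluates $\\boldsymbol{A}_h$ and $\\boldsymbol{Q}_h$ with Van Loan's block-matrix exponential via `scipy.linalg.expm`. The cell then checks the result against the closed-form transition of the scalar Ornstein-Uhlenbeck process." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"from scipy.linalg import expm\n",
"\n",
"\n",
"def discretize_lti_sde(drift_matrix, dispersion_matrix, h):\n",
"    # Hypothetical helper (not part of probnum): exact transition (A_h, Q_h)\n",
"    # of dx = F x dt + L dw over a step h, assuming unit-diffusion Brownian\n",
"    # motion w and no force vector, via Van Loan's block-matrix exponential.\n",
"    F = np.atleast_2d(drift_matrix)\n",
"    L = np.atleast_2d(dispersion_matrix)\n",
"    d = F.shape[0]\n",
"    # Block matrix [[-F, L L^T], [0, F^T]], scaled by the step size h\n",
"    block = np.zeros((2 * d, 2 * d))\n",
"    block[:d, :d] = -F\n",
"    block[:d, d:] = L @ L.T\n",
"    block[d:, d:] = F.T\n",
"    block_exp = expm(block * h)\n",
"    A_h = block_exp[d:, d:].T  # equals expm(F * h)\n",
"    Q_h = A_h @ block_exp[:d, d:]  # process noise covariance\n",
"    return A_h, Q_h\n",
"\n",
"\n",
"# Scalar Ornstein-Uhlenbeck process with the parameter values used below\n",
"theta, sigma, h = 0.21, np.sqrt(0.5), 0.2\n",
"A_h, Q_h = discretize_lti_sde(-theta * np.eye(1), sigma * np.eye(1), h)\n",
"\n",
"# Closed-form OU transition for comparison\n",
"A_exact = np.exp(-theta * h)\n",
"Q_exact = sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * h))\n",
"print(A_h.item(), A_exact)\n",
"print(Q_h.item(), Q_exact)\n",
"assert np.isclose(A_h.item(), A_exact) and np.isclose(Q_h.item(), Q_exact)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the Ornstein-Uhlenbeck process, $A_h = e^{-\\theta h}$ and $Q_h = \\frac{\\sigma^2}{2 \\theta} \\left(1 - e^{-2 \\theta h}\\right)$, so $Q_h$ approaches the stationary variance $\\sigma^2 / (2 \\theta)$ as $h$ grows. The block-matrix exponential avoids evaluating the integral explicitly and applies unchanged to any state dimension $d$."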
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define State-Space Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### I. Continuous Dynamics Model: Linear, Time-Invariant Stochastic Differential Equation (LTISDE)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "state_dim = 1\n", "observation_dim = 1" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Step size of the time grid on which we generate measurements (used below)\n", "delta_t = 0.2\n", "\n", "# Parameters of the linear, time-invariant SDE that models the (scalar)\n", "# Ornstein-Uhlenbeck process dx = -theta * x dt + sigma * dw\n", "drift_constant = 0.21  # theta\n", "dispersion_constant = np.sqrt(0.5)  # sigma\n", "drift = -drift_constant * np.eye(state_dim)\n", "force = np.zeros(state_dim)\n", "dispersion = dispersion_constant * np.eye(state_dim)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In `probnum`, continuous linear time-invariant dynamics are represented by the `LTISDE` class, the continuous-time counterpart of the discrete `LTIGaussian` model." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Create dynamics model\n", "dynamics_model = randprocs.markov.continuous.LTISDE(\n", "    drift_matrix=drift,\n", "    force_vector=force,\n", "    dispersion_matrix=dispersion,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### II. Discrete Measurement Model: Linear, Time-Invariant Gaussian Measurements" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "measurement_marginal_variance = 0.1\n", "measurement_matrix = np.eye(observation_dim, state_dim)\n", "measurement_noise_matrix = measurement_marginal_variance * np.eye(observation_dim)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The measurement model is a discrete, linear time-invariant Gaussian model (`LTIGaussian`). Only the dynamics are continuous, hence the name continuous-discrete."
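, "\n", "\n", "Conditioning a Gaussian belief about the state on one such measurement has a closed-form solution; this is the update step of the Kalman filter applied later in this notebook. The cell below is a minimal NumPy sketch of that step; `measurement_update` is a hypothetical helper, not a `probnum` function. It uses $\\boldsymbol{H} = \\boldsymbol{I}$ and $\\boldsymbol{R} = 0.1 \\, \\boldsymbol{I}$ as in this notebook, a prior $\\mathcal{N}(10, 1)$ that matches the initial state defined below, and a made-up measurement $y = 9.5$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def measurement_update(mean, cov, y, H, R):\n",
"    # Hypothetical helper: condition N(mean, cov) on y ~ N(H x, R)\n",
"    S = H @ cov @ H.T + R  # innovation covariance\n",
"    K = np.linalg.solve(S, H @ cov).T  # Kalman gain cov H^T S^{-1} (cov, S symmetric)\n",
"    new_mean = mean + K @ (y - H @ mean)\n",
"    new_cov = cov - K @ S @ K.T\n",
"    return new_mean, new_cov\n",
"\n",
"\n",
"# Illustrative inputs: prior N(10, 1), H = 1, R = 0.1, made-up measurement y = 9.5\n",
"H = np.eye(1)\n",
"R = 0.1 * np.eye(1)\n",
"post_mean, post_cov = measurement_update(np.array([10.0]), np.eye(1), np.array([9.5]), H, R)\n",
"print(post_mean, post_cov)\n",
"\n",
"# Scalar check: the posterior variance is P R / (P + R) with P = 1\n",
"assert np.isclose(post_cov.item(), 0.1 / 1.1)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With these numbers the posterior variance is $P R / (P + R) = 0.1 / 1.1 \\approx 0.091$, roughly eleven times smaller than the prior variance $P = 1$. The posterior mean $\\approx 9.55$ moves most of the way from 10 toward the measurement 9.5 because the measurement is much more precise than the prior."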
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "noise = randvars.Normal(mean=np.zeros(observation_dim), cov=measurement_noise_matrix)\n", "# The measurement matrix H plays the role of the transition matrix here\n", "measurement_model = randprocs.markov.discrete.LTIGaussian(\n", "    transition_matrix=measurement_matrix,\n", "    noise=noise,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### III. Initial State Random Variable" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "mu_0 = 10.0 * np.ones(state_dim)\n", "sigma_0 = np.eye(state_dim)\n", "initial_state_rv = randvars.Normal(mean=mu_0, cov=sigma_0)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Prior process: the SDE started from initial_state_rv at t_0 = 0\n", "prior_process = randprocs.markov.MarkovProcess(\n", "    transition=dynamics_model, initrv=initial_state_rv, initarg=0.0\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate Data for the State-Space Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, sample both latent states and noisy observations from the specified state-space model."
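, "\n", "\n", "Conceptually, this alternates between the transition of the dynamics and the measurement model. The cell below is a minimal NumPy sketch of such a procedure for the scalar Ornstein-Uhlenbeck process with the parameters of this notebook. It uses the closed-form transition $x_{k+1} \\mid x_k \\sim \\mathcal{N}\\left(e^{-\\theta h} x_k, \\frac{\\sigma^2}{2 \\theta}\\left(1 - e^{-2 \\theta h}\\right)\\right)$ on a uniform grid with step $h$. It is our own illustration and makes no claim about how `generate_artificial_measurements` is implemented. It draws from a separate random generator (`sketch_rng`, an assumption of this sketch) so that the notebook's `rng` and the data generated below are unaffected." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"# Hypothetical standalone sketch; the parameter values mirror this notebook\n",
"theta, sigma, h = 0.21, np.sqrt(0.5), 0.2\n",
"meas_var = 0.1  # measurement noise variance R\n",
"sketch_times = np.arange(0.0, 10.0, step=h)\n",
"sketch_rng = np.random.default_rng(seed=42)  # separate generator, rng stays untouched\n",
"\n",
"# Exact one-step transition of the OU process on a uniform grid\n",
"a = np.exp(-theta * h)\n",
"q = sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * h))\n",
"\n",
"# Latent states: x(t_0) ~ N(10, 1), then x_{k+1} = a x_k + sqrt(q) * standard normal\n",
"sketch_states = np.empty(len(sketch_times))\n",
"sketch_states[0] = sketch_rng.normal(loc=10.0, scale=1.0)\n",
"for k in range(1, len(sketch_times)):\n",
"    sketch_states[k] = a * sketch_states[k - 1] + np.sqrt(q) * sketch_rng.normal()\n",
"\n",
"# Measurements: y_k = x(t_k) + v_k with v_k ~ N(0, R)\n",
"noise_draws = sketch_rng.normal(size=len(sketch_times))\n",
"sketch_observations = sketch_states + np.sqrt(meas_var) * noise_draws\n",
"print(sketch_states.shape, sketch_observations.shape)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since $e^{-\\theta h} \\approx 0.96$ for these parameters, consecutive states are strongly correlated. The process relaxes from its initial value of about 10 toward its stationary mean 0 on a time scale of $1 / \\theta \\approx 4.8$."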
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "time_grid = np.arange(0.0, 10.0, step=delta_t)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "latent_states, observations = randprocs.markov.utils.generate_artificial_measurements(\n", "    rng=rng,\n", "    prior_process=prior_process,\n", "    measmod=measurement_model,\n", "    times=time_grid,\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "regression_problem = TimeSeriesRegressionProblem(\n", "    observations=observations,\n", "    locations=time_grid,\n", "    measurement_models=[measurement_model] * len(time_grid),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Kalman Filtering" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the model is **linear** and Gaussian, the filtering and smoothing posteriors are Gaussian and can be computed exactly with a Kalman filter and a Rauch-Tung-Striebel smoother.\n", "According to Chapter 10 of [1], in the continuous-discrete case the predicted moments between two measurement times are solutions to linear ordinary differential equations, which `probnum` solves for us when we invoke the `.filtsmooth(...)` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### I. Kalman Filter" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "kalman_filter = filtsmooth.gaussian.Kalman(prior_process)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### II. Perform Kalman Filtering + Rauch-Tung-Striebel Smoothing" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "state_posterior, _ = kalman_filter.filtsmooth(regression_problem)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The method `filtsmooth` returns the smoothing posterior together with additional output, which we discard here. The posterior object provides convenience functions, for example for sampling and prediction.\n", "Calling the posterior on an array of time points evaluates it at these locations, which may lie between or beyond the measurement times.\n", "This yields a list of Gaussian random variables from which we extract the means and covariances in order to visualize them." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Dense evaluation grid that extends beyond the last measurement time\n", "grid = np.linspace(0, 11, 500)\n", "\n", "# Evaluating the posterior on the grid yields a list of Normal random variables\n", "posterior_state_rvs = state_posterior(grid)\n", "posterior_state_means = posterior_state_rvs.mean.squeeze()  # Shape: (num_time_points,)\n", "posterior_state_covs = posterior_state_rvs.cov  # Shape: (num_time_points, state_dim, state_dim)\n", "\n", "# Draw three sample paths from the posterior on the same grid\n", "samples = state_posterior.sample(rng=rng, size=3, t=grid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize Results" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [ "nbsphinx-thumbnail" ] }, "outputs": [], "source": [ "state_fig = plt.figure()\n", "\n", "ax = state_fig.add_subplot()\n", "\n", "# Plot means\n", "ax.plot(grid, posterior_state_means, label=\"posterior mean\")\n", "\n", "# Plot samples\n", "for smp in samples:\n", "    ax.plot(\n", "        grid,\n", "        smp[:, 0],\n", "        color=\"gray\",\n", "        alpha=0.75,\n", "        linewidth=1,\n", "        linestyle=\"dashed\",\n", "        label=\"sample\",\n", "    )\n", "\n", "\n", "# Plot a band of +/- 1.96 marginal standard deviations (95% credible interval)\n", "std_x = np.sqrt(np.abs(posterior_state_covs)).squeeze()\n", "ax.fill_between(\n", "    grid,\n", "    posterior_state_means - 1.96 * std_x,\n", "    posterior_state_means + 1.96 * std_x,\n", "    alpha=0.2,\n", "    label=\"1.96 marginal stddev\",\n", ")\n", "ax.scatter(time_grid, observations, marker=\".\", label=\"measurements\")\n", "# Add labels etc.\n", "ax.set_xlabel(\"t\")\n", "ax.set_title(r\"$x$\")\n", "\n", "# These two lines just remove duplicate labels (caused by the samples) from the legend\n", "handles, labels = ax.get_legend_handles_labels()\n", "by_label = dict(zip(labels, handles))\n", "\n", "ax.legend(\n", "    by_label.values(), by_label.keys(), loc=\"center left\", bbox_to_anchor=(1, 0.5)\n", ")\n", "\n", "\n", "state_fig.tight_layout()" ] } ], "metadata": { "celltoolbar": "Edit Metadata", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }