{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Probabilistic Programming 4: Mixture and dynamic models\n", "\n", "#### Goal \n", " - Learn how to infer the parameters of a Gaussian mixture model using variational inference in a probabilistic programming language.\n", " - Learn how to infer states and noise in linear Gaussian state-space model using variational inference in a probabilistic programming language.\n", "\n", "#### Materials \n", " - Mandatory\n", " - This notebook\n", " - Lecture notes on latent variable models\n", " - Lecture notes on dynamical models\n", " - Optional\n", " - [Review of latent variable models](https://doi.org/10.1146/annurev-statistics-022513-115657)\n", " - [Bayesian Filtering & Smoothing](https://www.cambridge.org/core/books/bayesian-filtering-and-smoothing/C372FB31C5D9A100F8476C1B23721A67)\n", " - [Differences between Julia and Matlab / Python](https://docs.julialang.org/en/v1/manual/noteworthy-differences/index.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that none of the material below is new. The point of the Probabilistic Programming sessions is to solve practical problems so that the concepts from Bert's lectures become less abstract." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "using Pkg\n", "Pkg.activate(\"../../../lessons/\")\n", "Pkg.instantiate();\n", "IJulia.clear_output();" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "using JLD\n", "using Statistics\n", "using LinearAlgebra\n", "using Distributions\n", "using RxInfer\n", "using ColorSchemes\n", "using LaTeXStrings\n", "using Plots\n", "default(label=\"\", grid=false, linewidth=3, margin=10Plots.pt)\n", "include(\"../scripts/clusters.jl\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Problem: Stone Tools\n", "\n", "Archeologists have asked for your help in analyzing data on tools made of stone. It is believed that primitive humans created tools by striking stones with others. During this process, the stone loses flakes, which have been preserved. The archeologists have recovered these flakes from various locations and time periods and want to know whether this stone tool shaping process has improved over the centuries." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data\n", "\n", "The data is available from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/StoneFlakes). Each instance represents summary information of the stone flakes for a particular site. We will be using the attributes _flaking angle_ (FLA) and the _proportion of the dorsal surface worked_ (PROZD) for now." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "data = load(\"../datasets/stone_flakes.jld\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I've done some pre-processing on the data set, namely [z-scoring](https://nl.wikipedia.org/wiki/Z-score) and removing two outliers. This reduces the scale of the attributes which helps numerical stability during optimization. Now let's visualize the data with a scatterplot." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scatter(data[\"observations\"][:,1], \n", " data[\"observations\"][:,2], \n", " label=\"\", \n", " xlabel=\"Proportion of worked dorsal surface (PROZD)\",\n", " ylabel=\"Flaking angle (FLA)\",\n", " size=(800,300))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model specification\n", "\n", "We will be clustering this data with a Gaussian mixture model, to see if we can identify clear types of stone tools. The generative model for a Gaussian mixture consists of:\n", "\n", "$$ p(X, z, \\pi, \\mu, \\Lambda) =\\ \\underbrace{p(X \\mid z, \\mu, \\Lambda)}_{\\text{likelihood}}\\ \\times \\ \\underbrace{p(z \\mid \\pi)}_{\\text{prior latent variables}} \\ \\times \\ \\underbrace{p(\\mu \\mid \\Lambda)\\ p(\\Lambda)\\ p(\\pi)}_{\\text{prior parameters}}$$\n", "\n", "with the likelihood of observation $X_i$ being a Gaussian raised to the power of the latent assignment variables $z$\n", "\n", "$$ p(X_i \\mid z, \\mu, \\Lambda) = \\prod_{k=1}^{K} \\mathcal{N}(X_i \\mid \\mu_k, \\Lambda_k^{-1})^{z_i = k}$$\n", "\n", "the prior for each latent variable $z_i$ being a Categorical distribution\n", "\n", "$$ p(z_i \\mid \\pi) = \\text{Categorical}(z_i \\mid \\pi) $$\n", "\n", "and priors for the parameters being\n", "\n", "$$ \\begin{aligned} p(\\Lambda_k) =&\\ \\text{Wishart}(\\Lambda_k \\mid V_0, n_0) \\qquad &\\text{for all}\\ k , \\\\ p(\\mu_k \\mid \\Lambda_k) =&\\ \\mathcal{N}(\\mu_k \\mid m_0, \\Lambda_k^{-1}) \\qquad &\\text{for all}\\ k , \\\\ p(\\pi) =&\\ \\text{Dirichlet}(\\pi \\mid a_0) \\quad . \\end{aligned}$$\n", "\n", "We can implement these equations nearly directly in ReactiveMP." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Data dimensionality\n", "num_features = size(data[\"observations\"],2)\n", "\n", "# Sample size\n", "num_samples = size(data[\"observations\"],1)\n", "\n", "# Number of mixture components\n", "num_components = 3;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mixture models can be sensitive to initialization, so we are going to specify the prior parameters explicitly." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dict{Symbol, Any} with 3 entries:\n", " :μ => [1.0 0.0 -1.0; -1.0 0.0 1.0]\n", " :π => [1.0, 1.0, 1.0]\n", " :Λ => ([1.0 0.0; 0.0 1.0;;; 1.0 0.0; 0.0 1.0;;; 1.0 0.0; 0.0 1.0], 2)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Identity matrix (convenience variable)\n", "Id = diagm(ones(num_features));\n", "\n", "# Prior scale matrices\n", "V0 = cat(Id, Id, Id, dims=3)\n", "\n", "# Prior degrees of freedom \n", "n0 = num_features\n", "\n", "# Prior means\n", "m0 = [ 1.0 0.0 -1.0;\n", " -1.0 0.0 1.0];\n", "\n", "# Prior concentration parameters\n", "a0 = ones(num_components);\n", "\n", "# Pack into dictionary\n", "prior_params = Dict{Symbol,Any}(:Λ => (V0,n0), :μ => m0, :π => a0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now to specify the full model." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "@model function GMM(prior_params; K=1, N=1)\n", " \n", " # Allocate variables\n", " X = datavar(Vector{Float64}, N)\n", " z = randomvar(N) \n", " μ = randomvar(K)\n", " Λ = randomvar(K)\n", " \n", " # Unpack prior parameters\n", " V0 = prior_params[:Λ][1]\n", " n0 = prior_params[:Λ][2]\n", " m0 = prior_params[:μ]\n", " a0 = prior_params[:π]\n", " \n", " # Component parameters\n", " for k in 1:K\n", " Λ[k] ~ Wishart(n0, V0[:,:,k])\n", " μ[k] ~ MvNormalMeanPrecision(m0[:,k], Λ[k])\n", " end\n", " \n", " # Mixture weights\n", " π ~ Dirichlet(a0)\n", " \n", " cmeans = tuple(μ...)\n", " cprecs = tuple(Λ...)\n", " \n", " for i in 1:N\n", " \n", " # Latent assignment variable\n", " z[i] ~ Categorical(π)\n", " \n", " # Mixture distribution\n", " X[i] ~ NormalMixture(z[i], cmeans, cprecs) where { q = MeanField() }\n", " \n", " end\n", " return X,z,π,μ,Λ\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set up the inference procedure." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:05\u001b[39m\n" ] }, { "data": { "text/plain": [ "Inference results:\n", " Posteriors | available for (μ, π, Λ, z)\n", " Free Energy: | Real[217.442, 202.111, 200.341, 199.685, 199.169, 198.64, 198.078, 197.534, 197.088, 196.743 … 193.448, 193.448, 193.448, 193.448, 193.448, 193.448, 193.448, 193.448, 193.448, 193.448]\n" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Map data to list of vectors\n", "observations = [data[\"observations\"][i,:] for i in 1:num_samples]\n", "\n", "# Set variational distribution factorization\n", "constraints = @constraints begin\n", " q(z,π,μ,Λ) = q(z)q(π)q(μ)q(Λ) \n", "end\n", "\n", "# Initialize variational distributions\n", "initmarginals = (\n", " π = Dirichlet(a0),\n", " μ = [MvNormalMeanCovariance(m0[:,k], Id) for k in 1:num_components],\n", " Λ = [Wishart(n0, V0[:,:,k]) for k in 1:num_components]\n", ")\n", "\n", "# Iterations of variational inference\n", "num_iters = 100\n", "\n", "# Variational inference\n", "results = inference(\n", " model = GMM(prior_params, K=num_components, N=num_samples),\n", " data = (X = observations,),\n", " constraints = constraints,\n", " initmarginals = initmarginals,\n", " iterations = num_iters,\n", " free_energy = true,\n", " showprogress = true,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alright, we're done. Let's track the evolution of free energy over iterations." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot(1:num_iters, \n", " results.free_energy, \n", " color=\"black\", \n", " xscale=:log10,\n", " xlabel=\"Number of iterations\", \n", " ylabel=\"Free Energy\", \n", " size=(800,300))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That looks like it is nicely decreasing. Let's now visualize the cluster on top of the observations." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Extract mixture weights\n", "π_hat = mean(results.posteriors[:π][num_iters])\n", "\n", "# Extract component means\n", "μ_hat = mean.(results.posteriors[:μ][num_iters])\n", "\n", "# Extract covariance matrices\n", "Σ_hat = inv.(mean.(results.posteriors[:Λ][num_iters]))\n", "\n", "# Select dimensions to plot\n", "xlims = [minimum(data[\"observations\"][:,1])-1, maximum(data[\"observations\"][:,1])+1]\n", "ylims = [minimum(data[\"observations\"][:,2])-1, maximum(data[\"observations\"][:,2])+1]\n", "\n", "# Plot data and overlay estimated posterior probabilities\n", "plot_clusters(data[\"observations\"][:, 1:2], \n", " μ=μ_hat, \n", " Σ=Σ_hat, \n", " xlims=xlims, \n", " ylims=ylims,\n", " xlabel=\"Proportion of worked dorsal surface (PROZD)\",\n", " ylabel=\"Flaking angle (FLA)\",\n", " colors=[:reds, :blues, :greens],\n", " figsize=(800,500))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That doesn't look bad. The three Gaussians nicely cover all samples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "#### Exercise\n", "\n", "Play around with the number of components. Can you get an equally good coverage with just 2 components? What if you had 4?\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also plot the evolution of the parameters over iterations of the variational inference procedure." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Extract mean and standard deviation from \n", "mean_π_iters = cat(mean.(results.posteriors[:π])..., dims=2)\n", "vars_π_iters = cat(var.( results.posteriors[:π])..., dims=2)\n", "\n", "plot(1:num_iters, \n", " mean_π_iters', \n", " ribbon=vars_π_iters', \n", " xscale=:log10,\n", " xlabel=\"Number of iterations\", \n", " ylabel=\"Free Energy\", \n", " size=(800,300))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "#### Exercise\n", "\n", "Plot the evolution of one of the component means over iterations of variational inference, including its variance.\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Problem: shaking buildings\n", "\n", "Suppose you are contacted to estimate how resistant a building is to shaking caused by minor earthquakes. You decide to model the building as a [mass-spring-damper](https://en.wikipedia.org/wiki/Mass-spring-damper_model) system, described by:\n", "\n", "$$m \\ddot{x} + c\\dot{x} + kx = w \\, ,$$\n", "\n", "where $m$ corresponds to the mass of the building, $c$ is friction and $k$ the stiffness of the building's main supports. You don't know the external force acting upon the building and decide to model it as white noise $w$. In essence, this means you think the building will be pushed to the left as strongly on average as it will be pushed to the right. A simple discretization scheme with substituted variable $z = [x \\ \\dot{x}]$ and time-step $\\Delta t$ yields:\n", "\n", "$$z_{k} = \\underbrace{\\begin{bmatrix} 1 & \\Delta t \\\\ \\frac{-k}{m}\\Delta t & \\frac{-c}{m}\\Delta t + 1 \\end{bmatrix}}_{A} z_{k-1} + q_k \\, ,$$\n", "\n", "where $q_k \\sim \\mathcal{N}(0, Q)$ with covariance matrix $Q$. \n", "\n", "You place a series of sensors on the building that measure the displacement:\n", "\n", "$$ y_k = \\underbrace{\\begin{bmatrix} 1 & 0 \\end{bmatrix}}_{C} z_k + r_k \\, $$\n", "\n", "where the measurement noise is white as well: $r_k \\sim \\mathcal{N}(0, \\sigma^2)$. You have a good estimate of the variance $\\sigma^2$ from previous sensor calibrations and decide to consider it a known variable." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load data from file\n", "data = load(\"../datasets/shaking_buildings.jld\")\n", "\n", "# Data\n", "states = data[\"states\"]\n", "observations = data[\"observations\"]\n", "\n", "# Parameters\n", "mass = data[\"m\"]\n", "friction = data[\"c\"]\n", "stiffness = data[\"k\"]\n", "\n", "# Measurement noise variance\n", "σ = data[\"σ\"]\n", "\n", "# Time\n", "Δt = data[\"Δt\"]\n", "T = length(observations)\n", "time = range(1,step=Δt,length=T)\n", "\n", "plot(time, states[1,:], color=\"red\", label=\"states\", xlabel=\"time (sec)\", ylabel=\"train position\")\n", "scatter!(time, observations, color=\"black\", label=\"observations\", legend=:topleft, size=(800,300))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model specification\n", "\n", "Following the steps from the [lecture on Dynamic Systems](https://nbviewer.org/github/bertdv/BMLIP/blob/master/lessons/notebooks/Dynamic-Models.ipynb), we can derive the following probabilistic state-space model:\n", "\n", "$$\\begin{aligned} p(z_k \\mid z_{k-1}) =&\\ \\mathcal{N}(z_k \\mid A z_{k-1}, Q)\\\\ p(y_k \\mid z_k) =&\\ \\mathcal{N}(y_k \\mid C z_k, \\sigma^2) \\, . \\end{aligned}$$\n", "\n", "For now, we will use a simple structure for the process noise covariance matrix, e.g. $Q = I$. 
If we consider a Gaussian prior distribution for the initial state\n", "\n", "$$ p(z_0) = \\mathcal{N}(m_0, S_0) \\, ,$$\n", "\n", "we obtain a complete generative model:\n", "\n", "$$\\begin{aligned} \\underbrace{p(y_{1:T}, z_{0:T})}_{\\text{generative model}} = \\underbrace{p(z_0)}_{\\text{prior}} \\, \\prod_{k=1}^T \\, \\underbrace{p(y_k \\mid z_k)}_{\\text{likelihood}} \\, \\underbrace{p(z_k \\mid z_{k-1})}_{\\text{state transition}} \\end{aligned}$$\n", "\n", "To define this model in ReactiveMP, we must start with the process matrices:" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2×2 Matrix{Float64}:\n", " 1.0 0.0\n", " 0.0 1.0" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Transition matrix\n", "A = [1 Δt; -stiffness/mass*Δt -friction/mass*Δt+1]\n", "\n", "# Emission vector (its dot product with the state gives the scalar observation)\n", "C = [1.0, 0.0]\n", "\n", "# Set process noise covariance matrix\n", "Q = diagm(ones(2))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define a linear Gaussian dynamical system with only the states as unknown variables:" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Inference results:\n", " Posteriors | available for (z_0, z)\n", " Free Energy: | Real[434.098]\n" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@model function LGDS(prior_params, A,C,Q, σ; T=1)\n", " \"State estimation in a linear Gaussian dynamical system\"\n", " \n", " z = randomvar(T)\n", " y = datavar(Float64,T)\n", " \n", " # Prior state\n", " z_0 ~ MvNormalMeanCovariance(prior_params[:z0][1], prior_params[:z0][2])\n", " \n", " z_kmin1 = z_0\n", " for k in 1:T\n", " \n", " # State transition\n", " z[k] ~ MvNormalMeanCovariance(A * z_kmin1, Q)\n", " \n", " # Likelihood\n", " y[k] ~ NormalMeanVariance(dot(C, z[k]), σ^2)\n", " \n", " # Update recursive aux\n", " z_kmin1 = z[k]\n", " \n", " end\n", " return y, z\n", "end\n", "\n", "# Initial state prior\n", "prior_params = Dict(:z0 => (zeros(2), diageye(2)))\n", "\n", "(posteriors,_) = inference(\n", " model = LGDS(prior_params, A,C,Q, σ, T=T),\n", " data = (y = [observations[k] for k in 1:T],),\n", " free_energy = true,\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's extract the inferred states and visualize them."
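] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As an aside before we do: with $A$, $C$, $Q$ and $\\sigma^2$ all known, the exact posterior marginals computed by message passing in this chain coincide with classical Kalman filtering and smoothing (see the Bayesian Filtering & Smoothing reference above). For comparison, the forward (filtering) recursions for this model read\n", "\n", "$$\\begin{aligned} m_{k|k-1} =&\\ A m_{k-1} \\, , \\qquad P_{k|k-1} = A P_{k-1} A^{\\top} + Q \\, , \\\\ K_k =&\\ P_{k|k-1} C^{\\top} \\left( C P_{k|k-1} C^{\\top} + \\sigma^2 \\right)^{-1} \\, , \\\\ m_k =&\\ m_{k|k-1} + K_k \\left( y_k - C m_{k|k-1} \\right) \\, , \\qquad P_k = \\left( I - K_k C \\right) P_{k|k-1} \\, , \\end{aligned}$$\n", "\n", "with the smoothed marginals following from an additional backward sweep."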
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m_z = cat(mean.(posteriors[:z])...,dims=2)\n", "v_z = cat(var.( posteriors[:z])...,dims=2)\n", "\n", "plot(time, states[1,:], color=\"red\", label=\"states\", xlabel=\"time (sec)\", ylabel=\"train position\")\n", "plot!(time, m_z[1,:], color=\"blue\", ribbon=v_z[1,:], label=\"inferred\")\n", "scatter!(time, observations, color=\"black\", alpha=0.2, label=\"observations\", legend=:bottomright, size=(800,300))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mmmh... the inferred states are not smooth at all. This is most likely due to our process noise covariance matrix not being calibrated. So can we improve? \n", "\n", "---\n", "\n", "Of course, as Bayesians, we can just treat $Q$ as an unknown random variable and infer its posterior distribution. Adjusting the model is straightforward. The probabilistic state-space model becomes:\n", "\n", "$$\\begin{aligned} p(z_k \\mid z_{k-1},Q) =&\\ \\mathcal{N}(z_k \\mid A z_{k-1}, Q)\\\\ p(y_k \\mid z_k) =&\\ \\mathcal{N}(y_k \\mid C z_k, \\sigma^2) \\, , \\end{aligned}$$\n", "\n", "with priors\n", "\n", "$$\\begin{aligned} p(Q) =&\\ \\mathcal{W}^{-1}(\\nu, \\Lambda) \\\\ p(z_0) =&\\ \\mathcal{N}(m_0, S_0) \\, . 
\\end{aligned}$$\n", "\n", "The $\\mathcal{W}^{-1}$ represents an inverse-Wishart distribution with degrees-of-freedom $\\nu$ and scale matrix $\\Lambda$. This will give the following generative model:\n", "\n", "$$\\begin{aligned} \\underbrace{p(y_{1:T}, z_{0:T}, Q)}_{\\text{generative model}} = \\underbrace{p(z_0)p(Q)}_{\\text{priors}} \\, \\prod_{k=1}^T \\, \\underbrace{p(y_k \\mid z_k)}_{\\text{likelihood}} \\, \\underbrace{p(z_k \\mid z_{k-1},Q)}_{\\text{state transition}}\n", "\\end{aligned}$$\n", "\n", "Our model definition in ReactiveMP is only slightly larger:" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "@model function LGDS_Q(prior_params, A,C, σ; T=1)\n", " \"State estimation in a linear Gaussian dynamical system with unknown process noise\"\n", " \n", " z = randomvar(T)\n", " y = datavar(Float64,T)\n", " \n", " # Prior state\n", " z_0 ~ MvNormalMeanCovariance(prior_params[:z0][1], prior_params[:z0][2])\n", " \n", " # Process noise covariance matrix\n", " Q ~ InverseWishart(prior_params[:Q][1], prior_params[:Q][2])\n", " \n", " z_kmin1 = z_0\n", " for k in 1:T\n", " \n", " # State transition\n", " z[k] ~ MvNormalMeanCovariance(A * z_kmin1, Q)\n", " \n", " # Likelihood\n", " y[k] ~ NormalMeanVariance(dot(C, z[k]), σ^2)\n", " \n", " # Update recursive aux\n", " z_kmin1 = z[k]\n", " \n", " end\n", " return y, z, Q\n", "end" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "#### Exercise\n", "\n", "Think of what might be appropriate prior parameters for the Inverse-Wishart distribution. Should its mean be high or low here? (Recall that a $d$-dimensional $\\mathcal{W}^{-1}(\\nu, \\Lambda)$ has mean $\\Lambda / (\\nu - d - 1)$ for $\\nu > d + 1$.)\n", "\n", "---" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "I have chosen the following set of prior parameters:" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dict{Symbol, Tuple{Any, Matrix{Float64}}} with 2 entries:\n", " :Q => (10, [1.0 0.0; 0.0 1.0])\n", " :z0 => ([0.0, 0.0], [1.0 0.0; 0.0 1.0])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define prior parameters\n", "prior_params = Dict(:z0 => (zeros(2), diageye(2)),\n", " :Q => (10, diageye(2)))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The variational inference procedure for estimating states and the process noise covariance matrix simultaneously requires a bit more thought, but remains straightforward:" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02\u001b[39m\n" ] }, { "data": { "text/plain": [ "Inference results:\n", " Posteriors | available for (z_0, Q, z)\n", " Free Energy: | Real[425.591, 368.392, 359.522, 356.68, 356.422, 356.522, 356.717, 356.87, 356.851, 356.897 … 357.077, 357.078, 357.078, 357.078, 357.079, 357.079, 357.079, 357.08, 357.08, 357.08]\n" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Iterations of variational inference\n", "num_iters = 100\n", "\n", "# Initialize variational marginal distributions and messages\n", "inits = Dict(:z => MvNormalMeanCovariance(zeros(2), diageye(2)),\n", " :Q => InverseWishart(10, diageye(2)))\n", "\n", "# Define variational distribution factorization\n", "constraints = @constraints begin\n", " q(z_0, z,Q) = q(z_0, z)q(Q)\n", "end\n", "\n", "# Variational inference procedure\n", "results = inference(\n", " 
model = LGDS_Q(prior_params, A,C, σ, T=T),\n", " data = (y = [observations[k] for k in 1:T],),\n", " constraints = constraints,\n", " iterations = num_iters,\n", " options = (limit_stack_depth = 100,),\n", " initmarginals = inits,\n", " initmessages = inits,\n", " free_energy = true,\n", " showprogress = true,\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Again, let's inspect the free energy to see if we have converged." ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[line plot: free energy vs. number of iterations]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot(1:num_iters, \n", " results.free_energy, \n", " color=\"black\", \n", " xscale=:log10,\n", " xlabel=\"Number of iterations\", \n", " ylabel=\"Free Energy\", \n", " size=(800,300))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Alright. That looks good. Let's extract the inferred states and visualize." ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[line plot: true states, inferred states with variance ribbon, and observations over time]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m_z = cat(mean.(last(results.posteriors[:z]))...,dims=2)\n", "v_z = 
cat(var.(last(results.posteriors[:z]))...,dims=2)\n", "\n", "plot(time, states[1,:], color=\"red\", label=\"states\", xlabel=\"time (sec)\", ylabel=\"building position\")\n", "plot!(time, m_z[1,:], color=\"blue\", ribbon=v_z[1,:], label=\"inferred\")\n", "scatter!(time, observations, color=\"black\", alpha=0.2, label=\"observations\", legend=:topleft, size=(800,300))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "That's much smoother. The free energy of this model ($\\mathcal{F} = 357.08$) is smaller than that of the earlier model with $Q$ set to an identity matrix ($\\mathcal{F} = 434.10$). That means that the added cost of inferring the matrix $Q$ is offset by the increase in performance it provides. \n", "\n", "The error with respect to the true states seems smaller as well, but in practice we of course can't check this.\n", "\n", "Let's inspect the inferred process noise covariance matrix:" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2×2 Matrix{Float64}:\n", " 0.0600678 0.00535827\n", " 0.00535827 0.145712" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Posterior mean of the process noise covariance matrix\n", "Q_MAP = mean(last(results.posteriors[:Q]))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We do not have enough data to recover the true process noise covariance matrix exactly, but the result is much closer to the truth than the identity matrix we started with." ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2×2 Matrix{Float64}:\n", " 0.00173611 0.0240885\n", " 0.0240885 0.549072" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# True process noise covariance matrix\n", "Q_true = data[\"Q\"]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize that for a closer look:" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[heatmaps: estimated vs. true process noise covariance matrix]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Colorbar limits\n", "clims = (minimum([Q_MAP[:]; Q_true[:]]), maximum([Q_MAP[:]; Q_true[:]]))\n", " \n", "# Plot covariance matrices as heatmaps\n", "p401 = heatmap(Q_MAP, axis=([], false), yflip=true, title=\"Estimated\", clims=clims)\n", "p402 = heatmap(Q_true, axis=([], false), yflip=true, title=\"True\", clims=clims)\n", "plot(p401,p402, layout=(1,2), size=(900,300))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "#### Exercise\n", "\n", "Can you come up with a way to improve the model even further?\n", "\n", "---" ] } ], "metadata": { "@webio": { "lastCommId": null, "lastKernelId": null }, "kernelspec": { "display_name": "Julia 1.8.1", "language": "julia", "name": "julia-1.8" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.8.1" } }, "nbformat": 4, "nbformat_minor": 4 }