{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Simulation-based calibration\n", "\n", "[Simulation-based calibration](https://arxiv.org/abs/1804.06788) (SBC) is a method for visually validating Bayesian inference. SBC helps detect misspecified models, inaccurate computation, and bugs in the implementation of a probabilistic program. It therefore validates the technical correctness of the program that produces an inference, not any particular inference on observed data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Consider the following generative process. We draw parameters from the prior, simulate data from them, and then draw $L$ samples from the resulting posterior:\n", "\n", "\\begin{align}\n", "\\tilde{\\theta} & \\sim P(\\theta) \\\\\n", "\\tilde{y} & \\sim P(y \\mid \\tilde{\\theta}) \\\\\n", "\\{ \\theta_1, \\dots, \\theta_L \\} & \\sim P(\\theta \\mid \\tilde{y})\n", "\\end{align}\n", "\n", "The rank of the prior draw $\\tilde{\\theta}$ relative to $L$ independent draws $\\{ \\theta_1, \\dots, \\theta_L \\}$ from the *exact* posterior,\n", "\n", "\\begin{align}\n", "r(\\{ \\theta_1, \\dots, \\theta_L \\}, \\tilde{\\theta}) = \\sum_{l=1}^{L} \\mathbb{I}(\\theta_l < \\tilde{\\theta}) + 1,\n", "\\end{align}\n", "\n", "is a discrete uniform random variable on $\\{1, \\dots, L + 1\\}$ (see the original paper for a proof). Any deviation from uniformity therefore signals a problem, which gives us a procedure for testing whether our inferences work." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We follow Algorithm 2 from the original paper and implement SBC in Stan." 
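] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before turning to Stan, the uniformity claim above can be checked in a toy setting where the exact posterior is known in closed form, so no MCMC is involved. This is a minimal sketch under an assumption that is not part of the original paper or of this notebook: a conjugate model with $\\theta \\sim \\mathcal{N}(0, 1)$ and $y \\mid \\theta \\sim \\mathcal{N}(\\theta, 1)$, whose posterior is $\\theta \\mid y \\sim \\mathcal{N}(y / 2, 1 / 2)$. The names `n_sims`, `L_toy`, `toy_rank` and `toy_ranks` are hypothetical and used only in this cell. Because the posterior draws are exact and independent, a chi-squared test of the rank counts against a uniform distribution should usually not reject." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Toy check of rank uniformity with an exact posterior (assumed conjugate model):\n", "# theta ~ N(0, 1), y | theta ~ N(theta, 1)  =>  theta | y ~ N(y / 2, 1 / 2)\n", "set.seed(1)\n", "n_sims <- 2000  # hypothetical number of replications\n", "L_toy <- 19     # hypothetical number of posterior draws, so ranks lie in 1..20\n", "\n", "toy_rank <- function()\n", "{\n", "    theta_tilde <- rnorm(1, 0, 1)                   # prior draw\n", "    y_tilde <- rnorm(1, theta_tilde, 1)             # simulated data\n", "    post <- rnorm(L_toy, y_tilde / 2, sqrt(1 / 2))  # exact, independent posterior draws\n", "    sum(post < theta_tilde) + 1                     # rank as defined above\n", "}\n", "\n", "toy_ranks <- replicate(n_sims, toy_rank())\n", "counts <- tabulate(toy_ranks, nbins = L_toy + 1)\n", "# Under correct inference each rank has probability 1 / (L_toy + 1)\n", "chisq.test(counts)$p.value" 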
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading required package: ggplot2\n", "Loading required package: StanHeaders\n", "rstan (Version 2.18.2, GitRev: 2e1f913d3ca3)\n", "For execution on a local, multicore CPU with excess RAM we recommend calling\n", "options(mc.cores = parallel::detectCores()).\n", "To avoid recompilation of unchanged Stan programs, we recommend calling\n", "rstan_options(auto_write = TRUE)\n" ] } ], "source": [ "library(rstan)\n", "suppressMessages(library(tidyverse))\n", "options(repr.plot.width = 6, repr.plot.height = 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SBC itself is simple. For each of $N$ iterations, we draw parameters from the prior, simulate data from them, and fit the model. Because MCMC draws are autocorrelated, we require each fit to reach a target effective sample size $n_{eff}$ that we choose: if it falls short, we refit with twice the thinning interval, until the target is reached or a maximum thinning interval is hit. This keeps the $L$ retained posterior draws approximately independent. We then compute the rank of every parameter using the sum defined above. If the inference works, the ranks are uniformly distributed." 
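] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The bookkeeping of the thinning loop can be checked with a little arithmetic. This sketch assumes rstan's default warmup of `floor(iter / 2)` together with the settings used later in this notebook (`L = 100`, `init_thin = 1`, `max_thin = 64`); `L_demo` and `thin_schedule` are hypothetical names used only here. Since the `sbc` loop stops as soon as `thin` reaches `max_thin`, the largest thinning interval actually tried is 32, and each call to `sampling` with `iter = 2 * thin * L` keeps exactly $L$ draws after warmup and thinning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Thinning schedule of the SBC loop (assumes rstan's default warmup = floor(iter / 2))\n", "L_demo <- 100                  # hypothetical name; matches L set later in the notebook\n", "thin_schedule <- 2^(0:5)       # thin values tried with init_thin = 1 and max_thin = 64\n", "iter <- 2 * thin_schedule * L_demo\n", "warmup <- floor(iter / 2)\n", "kept <- (iter - warmup) %/% thin_schedule  # draws saved per call\n", "data.frame(thin = thin_schedule, iter = iter, warmup = warmup, kept = kept)" 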
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Run SBC; uses the globals N, L, init_thin, max_thin and target_neff defined below\n", "sbc <- function(model, data)\n", "{\n", "    ranks <- matrix(0, N, 2, dimnames = list(NULL, c(\"beta\", \"sigma\")))\n", "    for (n in seq_len(N))\n", "    {\n", "        thin <- init_thin\n", "        # Refit with doubled thinning until the n_eff of lp__ reaches the target\n", "        while (thin < max_thin)\n", "        {\n", "            # Each call re-runs the transformed data block, i.e. draws new parameters and data\n", "            fit <- suppressWarnings(\n", "                sampling(model, data = data,\n", "                         chains = 1, iter = 2 * thin * L,\n", "                         thin = thin, control = list(adapt_delta = 0.99), refresh = 0)\n", "            )\n", "            n_eff <- summary(fit)$summary[\"lp__\", \"n_eff\"]\n", "            if (n_eff >= target_neff) break\n", "            thin <- 2 * thin\n", "        }\n", "        # idsim holds I(theta_l < theta_tilde) per draw; summing over the L draws gives rank - 1\n", "        ranks[n, ] <- apply(rstan::extract(fit)$idsim, 2, sum) + 1\n", "    }\n", "    ranks\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start with a simple linear regression without intercept. In the `transformed data` block, the Stan program draws the two parameters $\\beta$ and $\\sigma$ from their priors and simulates data `y_sim` from them; the model is then fit to `y_sim`. The `generated quantities` block compares each posterior draw with the simulated values, yielding the indicators that `sbc` sums into ranks. (Strictly, `sigma` should be declared as `real<lower=0> sigma;` so that the sampler never proposes values outside the support of its lognormal prior.)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data {\n", "\tint n;\n", "\tvector[n] x;\n", "}\n", "\n", "transformed data {\n", "\treal beta_sim = normal_rng(0, 1);\n", "\treal sigma_sim = lognormal_rng(0, 1);\n", "\t\n", "\tvector[n] y_sim;\n", "\tfor (i in 1:n)\n", "\t\ty_sim[i] = normal_rng(x[i] * beta_sim, sigma_sim);\n", "}\n", "\n", "parameters {\n", "\treal beta;\n", "\treal sigma;\n", "}\n", "\n", "model {\n", "\tbeta ~ normal(0, 1);\n", "\tsigma ~ lognormal(0, 1);\n", " \ty_sim ~ normal(x * beta, sigma);\n", "}\n", "\n", "generated quantities {\n", "\tint idsim[2] = { beta < beta_sim, sigma < sigma_sim };\n", "}\n" ] } ], "source": [ "model.file <- \"_models/sbc_1.stan\"\n", "cat(readLines(model.file), sep=\"\\n\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model <- stan_model(model.file)" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "We run $N = 5000$ SBC iterations and keep $L = 100$ posterior draws per iteration. The remaining settings control the thinning loop: `init_thin` is the initial thinning interval (the period at which draws are saved), `max_thin` bounds how far that interval is doubled, and `target_neff` is the effective sample size we require, here $0.8 L = 80$." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "N <- 5000              # number of SBC iterations\n", "L <- 100               # posterior draws kept per iteration\n", "init_thin <- 1         # initial thinning interval\n", "max_thin <- 64         # the interval is doubled only while below this bound\n", "target_neff <- .8 * L  # required n_eff of lp__" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also define a function that plots a histogram of the ranks of each parameter." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Histogram of the SBC ranks, one panel per parameter\n", "plot_fit <- function(ranks)\n", "{\n", "    as.data.frame(ranks) %>% \n", "        tidyr::gather(\"param\", \"value\") %>%\n", "        ggplot(aes(value)) +\n", "        geom_histogram(bins=30) +\n", "        scale_y_continuous(\"Count\") + \n", "        scale_x_continuous(\"\") + \n", "        facet_grid(. ~ param) +\n", "        ggthemes::theme_tufte()\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we run SBC and plot the ranks. beta;\n", "\treal sigma;\n", "}\n", "\n", "model {\n", "\tmu ~ normal(0, 10);\n", "\tsigma ~ lognormal(0, 5);\n", " \ty_sim ~ student_t(10, x * beta, sigma);\n", "}\n", "\n", "generated quantities {\n", "\tint idsim[2] = { beta < beta_sim, sigma < sigma_sim };\n", "}\n" ] } ], "source": [ "model.file <- \"_models/sbc_2.stan\"\n", "cat(readLines(model.file), sep=\"\\n\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": sigma_sim };\n", "}\n" ] } ], "source": [ "model.file <- \"_models/sbc_3.stan\"\n", "cat(readLines(model.file), sep=\"\\n\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": 