{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A Minimal Example for MCMCDebugging.jl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start by defining a Beta-Binomial model, two posterior samplers (one bug-free and one buggy), and a test function `g` used to compute the test statistics."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"g (generic function with 1 method)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Distributions, DynamicPPL\n",
"\n",
"# The marginal-conditional simulator defined by DynamicPPL\n",
"# See README.md for the expected form of the model definition.\n",
"@model function BetaBinomial(θ=missing, x=missing)\n",
" θ ~ Beta(2, 3)\n",
" x ~ Binomial(3, θ)\n",
" return θ, x\n",
"end\n",
"\n",
"# Posterior samplers used by the successive-conditional simulator\n",
"# 1. Bug-free posterior sampler\n",
"# With prior Beta(α, β) = Beta(2, 3) and n = 3 trials, the exact posterior is Beta(α + x, β + n - x).\n",
"rand_θ_given(x) = rand(Beta(2 + x, 3 + 3 - x))\n",
"# 2. Buggy posterior sampler: it behaves as if one extra success had been observed (capped at n = 3)\n",
"rand_θ_given_buggy(x) = rand_θ_given(min(3, x + 1))\n",
"\n",
"# Test function: stacks θ and x into a vector; one test statistic is computed per component\n",
"g(θ, x) = cat(θ, x; dims=1)"
]
},
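{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the posterior targeted by `rand_θ_given` follows from Beta-Binomial conjugacy. With prior $\\theta \\sim \\mathrm{Beta}(\\alpha, \\beta)$ and likelihood $x \\mid \\theta \\sim \\mathrm{Binomial}(n, \\theta)$,\n",
"\n",
"$$p(\\theta \\mid x) \\propto \\theta^{\\alpha - 1} (1 - \\theta)^{\\beta - 1} \\cdot \\theta^{x} (1 - \\theta)^{n - x} = \\theta^{\\alpha + x - 1} (1 - \\theta)^{\\beta + n - x - 1},$$\n",
"\n",
"which is the kernel of $\\mathrm{Beta}(\\alpha + x, \\beta + n - x)$. In this model $\\alpha = 2$, $\\beta = 3$ and $n = 3$, giving $\\mathrm{Beta}(2 + x, 6 - x)$. The buggy sampler instead draws from $\\mathrm{Beta}(2 + x', 6 - x')$ with $x' = \\min(3, x + 1)$, so its posterior mean $(2 + x')/8$ is too large whenever $x < 3$."
]
},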
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first perform the Geweke test on the bug-free sampler."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\n"
]
},
{
"data": {
"text/plain": [
"Geweke (Joint Distribution) Test\n",
"--------------------------------\n",
"Results:\n",
" Number of samples: 5000\n",
" Parameter dimension: 1\n",
" Data dimension: 1\n",
" Statistic: [1.0069391903102192, 0.46797362344774]\n",
" P-value: [0.3139639974176164, 0.6398034522528347]\n"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using MCMCDebugging\n",
"\n",
"res = perform(GewekeTest(5_000), BetaBinomial, rand_θ_given; g=g)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The p-values (about 0.31 and 0.64) are reasonably large ($\\geq0.1$), so the test finds no evidence of a bug in this sampler."
]
},
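{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the statistic and p-values more concrete, here is a from-scratch sketch of a Geweke-style joint distribution test. It is an illustration only, not how MCMCDebugging.jl implements `GewekeTest`: the helpers `marginal_conditional`, `successive_conditional`, `batch_mean_var` and `geweke_sketch` are hypothetical, the model parameters are hard-coded, and estimating the variance of the successive-conditional mean with batch means (to account for autocorrelation along the chain) is an assumed choice. For each component of `g`, it forms a z-statistic from the difference between the two sample means and converts it to a two-sided p-value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Distributions, Statistics\n",
"\n",
"# Hypothetical helpers for illustration; not part of MCMCDebugging.jl.\n",
"# Marginal-conditional simulator: independent draws θ ~ Beta(2, 3), x ~ Binomial(3, θ).\n",
"function marginal_conditional(n)\n",
"    θs = rand(Beta(2, 3), n)\n",
"    xs = [rand(Binomial(3, θ)) for θ in θs]\n",
"    return θs, xs\n",
"end\n",
"\n",
"# Successive-conditional simulator: alternate θ ~ sampler(x) and x ~ Binomial(3, θ),\n",
"# starting from an exact draw of the joint distribution.\n",
"function successive_conditional(sampler, n)\n",
"    θs, xs = zeros(n), zeros(Int, n)\n",
"    x = rand(Binomial(3, rand(Beta(2, 3))))\n",
"    for t in 1:n\n",
"        θs[t] = sampler(x)\n",
"        x = rand(Binomial(3, θs[t]))\n",
"        xs[t] = x\n",
"    end\n",
"    return θs, xs\n",
"end\n",
"\n",
"# Variance of the sample mean of an autocorrelated series, via non-overlapping batch means.\n",
"function batch_mean_var(v; nbatches=50)\n",
"    m = length(v) ÷ nbatches\n",
"    means = [mean(v[(i - 1) * m + 1:i * m]) for i in 1:nbatches]\n",
"    return var(means) / nbatches\n",
"end\n",
"\n",
"# One z-statistic and two-sided p-value per component of the test function g.\n",
"function geweke_sketch(sampler, n; g=(θ, x) -> [θ, x])\n",
"    gm = reduce(hcat, g.(marginal_conditional(n)...))             # g on independent draws\n",
"    gs = reduce(hcat, g.(successive_conditional(sampler, n)...))  # g along the chain\n",
"    zstat(k) = (mean(gm[k, :]) - mean(gs[k, :])) / sqrt(var(gm[k, :]) / n + batch_mean_var(gs[k, :]))\n",
"    stats = [zstat(k) for k in 1:size(gm, 1)]\n",
"    pvals = [2 * ccdf(Normal(), abs(z)) for z in stats]\n",
"    return stats, pvals\n",
"end\n",
"\n",
"# Usage with the samplers and test function defined above, e.g.\n",
"# geweke_sketch(rand_θ_given, 5_000; g=g) and geweke_sketch(rand_θ_given_buggy, 5_000; g=g)"
]
},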
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also visualise the result with a Q-Q plot; plotting is supported through Plots.jl."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Quantile error: 0.003602000000000002\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Plots\n",
"\n",
"plot(res, BetaBinomial(); size=(300, 300), title=\"Bug-free sampler\")"
]
},
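{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Q-Q plot compares matched empirical quantiles of two samples; if both come from the same distribution, the points lie close to the diagonal $y = x$. Here is a sketch of the idea for $\\theta$ alone, using the fact that under the joint distribution $\\theta$ is marginally $\\mathrm{Beta}(2, 3)$. The helpers `theta_chain`, `qq_points` and `quantile_gap` are hypothetical, and `quantile_gap` (the mean absolute gap between matched quantiles) is only an illustrative summary: it is not necessarily how MCMCDebugging.jl computes the quantile error printed above or what its plot recipe draws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Distributions, Statistics\n",
"\n",
"# Hypothetical helper: θ draws from a successive-conditional chain that alternates\n",
"# θ ~ sampler(x) and x ~ Binomial(3, θ), starting from an exact draw of the joint.\n",
"function theta_chain(sampler, n)\n",
"    θs = zeros(n)\n",
"    x = rand(Binomial(3, rand(Beta(2, 3))))\n",
"    for t in 1:n\n",
"        θs[t] = sampler(x)\n",
"        x = rand(Binomial(3, θs[t]))\n",
"    end\n",
"    return θs\n",
"end\n",
"\n",
"# Matched empirical quantiles of samples a and b at probability levels ps.\n",
"qq_points(a, b; ps=0.01:0.01:0.99) = (quantile(a, ps), quantile(b, ps))\n",
"\n",
"# Illustrative summary only: mean absolute gap between matched quantiles.\n",
"quantile_gap(a, b; ps=0.01:0.01:0.99) = mean(abs.(quantile(a, ps) .- quantile(b, ps)))\n",
"\n",
"# Usage (Plots.jl is loaded above); under the joint distribution, θ is marginally Beta(2, 3):\n",
"# θ_ref = rand(Beta(2, 3), 5_000)\n",
"# θ_chain = theta_chain(rand_θ_given, 5_000)   # or rand_θ_given_buggy\n",
"# qa, qb = qq_points(θ_ref, θ_chain)\n",
"# scatter(qa, qb; label=\"quantiles\", size=(300, 300))\n",
"# plot!(identity, 0, 1; label=\"y = x\")\n",
"# quantile_gap(θ_ref, θ_chain)"
]
},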
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's run the Geweke test on the buggy sampler."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\n",
"┌ Warning: Test function `g` is not provided. Statistic is not computed.\n",
"└ @ MCMCDebugging /Users/kai/projects/TuringLang/MCMCDebugging.jl/src/geweke.jl:77\n"
]
},
{
"data": {
"text/plain": [
"Geweke (Joint Distribution) Test\n",
"--------------------------------\n",
"Results:\n",
" Number of samples: 5000\n",
" Parameter dimension: 1\n",
" Data dimension: 1\n",
" Statistic: missing\n",
" P-value: missing\n",
"\n",
"Test statistic is missing. Please use `compute_statistic!(res, g)` \n",
"if you want compute statistic without rerun the simulation.\n",
"\n"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res_buggy = perform(GewekeTest(5_000), BetaBinomial, rand_θ_given_buggy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oops, I forgot to pass the test function `g`. Leaving it out is fine if we only want the Q-Q plot, but the statistic and p-values need it. Following the suggestion in the output, `compute_statistic!` updates the result from the existing samples without rerunning the simulation."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Geweke (Joint Distribution) Test\n",
"--------------------------------\n",
"Results:\n",
" Number of samples: 5000\n",
" Parameter dimension: 1\n",
" Data dimension: 1\n",
" Statistic: [-41.00888978073274, -24.577724389227683]\n",
" P-value: [0.0, 2.186427127224003e-133]\n"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"compute_statistic!(res_buggy, g)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The p-values are now essentially zero (0.0 and about $2 \\times 10^{-133}$), which is strong evidence against the hypothesis that the sampler is bug-free."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notably, the Q-Q plot is also very informative, and the quantile error is much larger than for the bug-free sampler (about 0.158 versus 0.0036)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Quantile error: 0.157802\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(res_buggy, BetaBinomial(); size=(300, 300), title=\"Buggy sampler\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also compute the maximum mean discrepancy (MMD) for each of the two results."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: MMD\n",
"│ mmd_of(res) = 9.005168477094205e-5\n",
"│ mmd_of(res_buggy) = 0.06918151052999455\n",
"└ @ Main In[7]:1\n"
]
}
],
"source": [
"@info \"MMD\" mmd_of(res) mmd_of(res_buggy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, the bug-free sampler attains a much smaller value (about $9 \\times 10^{-5}$ versus $0.069$), meaning that with it the marginal-conditional and successive-conditional simulators produce much more similar distributions."
]
}
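,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what the MMD measures, here is a minimal sketch of the standard unbiased estimator of the squared MMD between distributions $P$ and $Q$, $\\mathrm{MMD}^2(P, Q) = \\mathbb{E}[k(X, X')] + \\mathbb{E}[k(Y, Y')] - 2\\,\\mathbb{E}[k(X, Y)]$ with $X, X' \\sim P$ and $Y, Y' \\sim Q$ independent. The Gaussian kernel $k$, its bandwidth $\\sigma = 1$ and the helpers `gauss_kernel`, `mmd2_unbiased` and `draw_joint` are assumptions for illustration; `mmd_of` may use a different kernel, bandwidth or estimator. Because the estimator needs $O(mn)$ kernel evaluations, it is applied to 500 joint draws of $(\\theta, x)$. The first estimate compares two samples from the same distribution and should be close to zero (possibly slightly negative, since the estimator is unbiased); the second compares against a shifted prior and should be clearly larger."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Distributions, Random\n",
"\n",
"# Gaussian (RBF) kernel between two points; the bandwidth σ = 1 is an assumed default.\n",
"gauss_kernel(a, b; σ=1.0) = exp(-sum(abs2, a .- b) / (2 * σ^2))\n",
"\n",
"# Unbiased estimate of the squared MMD between samples X and Y (vectors of points);\n",
"# the within-sample sums exclude the diagonal terms i == j.\n",
"function mmd2_unbiased(X, Y; σ=1.0)\n",
"    m, n = length(X), length(Y)\n",
"    kxx = sum(gauss_kernel(X[i], X[j]; σ=σ) for i in 1:m for j in 1:m if i != j) / (m * (m - 1))\n",
"    kyy = sum(gauss_kernel(Y[i], Y[j]; σ=σ) for i in 1:n for j in 1:n if i != j) / (n * (n - 1))\n",
"    kxy = sum(gauss_kernel(X[i], Y[j]; σ=σ) for i in 1:m for j in 1:n) / (m * n)\n",
"    return kxx + kyy - 2 * kxy\n",
"end\n",
"\n",
"# Hypothetical helper: n joint draws [θ, x] with θ ~ prior and x ~ Binomial(3, θ).\n",
"draw_joint(n, prior) = [[θ, rand(Binomial(3, θ))] for θ in rand(prior, n)]\n",
"\n",
"Random.seed!(1)\n",
"X = draw_joint(500, Beta(2, 3))\n",
"Y = draw_joint(500, Beta(2, 3))  # same distribution as X\n",
"Z = draw_joint(500, Beta(8, 2))  # θ shifted towards 1\n",
"mmd2_unbiased(X, Y), mmd2_unbiased(X, Z)"
]
}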
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.5.0",
"language": "julia",
"name": "julia-1.5"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.5.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}