{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# A Minimal Example for MCMCDebugging.jl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start by defining a Beta-Binomial model, two posterior samplers (one bug-free and one buggy) and a test function for the Geweke test." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "g (generic function with 1 method)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using Distributions, DynamicPPL\n", "\n", "# The marginal-conditional simulator defined by DynamicPPL\n", "# See README.md for the expected form of the model definition.\n", "@model function BetaBinomial(θ=missing, x=missing)\n", "    θ ~ Beta(2, 3)\n", "    x ~ Binomial(3, θ)\n", "    return θ, x\n", "end\n", "\n", "# The successive-conditional simulator\n", "# 1. Bug-free posterior sampler\n", "# Beta(α + x, β + n - x), here with α = 2, β = 3 and n = 3, is the true posterior.\n", "rand_θ_given(x) = rand(Beta(2 + x, 3 + 3 - x))\n", "# 2. Buggy posterior sampler: conditions on min(3, x + 1) instead of x\n", "rand_θ_given_buggy(x) = rand_θ_given(min(3, x + 1))\n", "\n", "# Test function: concatenates θ and x into a single vector\n", "g(θ, x) = cat(θ, x; dims=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first perform the Geweke test for the bug-free sampler." 
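, "\n", "\n", "For intuition, the Geweke test compares two ways of simulating $(\\theta, x)$ from the joint distribution: the marginal-conditional simulator draws $\\theta$ from the prior and then $x$ from the likelihood, while the successive-conditional simulator starts from the prior and then alternates between the posterior sampler ($\\theta$ given $x$) and the likelihood ($x$ given $\\theta$). The cell below is a simplified sketch of both simulators using only Distributions.jl, for illustration; it is not MCMCDebugging.jl's implementation, and the helper names (`marginal_conditional`, `successive_conditional`, `sampler_ok`, `sampler_buggy`) and the seed are hypothetical. If the posterior sampler is correct, both simulators should give roughly the same mean of $\\theta$ (the prior mean, $2/5$), whereas the buggy sampler should drift upwards." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "using Distributions, Random, Statistics\n", "\n", "# Hypothetical helper: marginal-conditional simulator, i.e. n independent\n", "# draws of θ from the Beta(2, 3) prior and x from the Binomial(3, θ) likelihood\n", "function marginal_conditional(rng, n)\n", "    θs = rand(rng, Beta(2, 3), n)\n", "    xs = [rand(rng, Binomial(3, θ)) for θ in θs]\n", "    return θs, xs\n", "end\n", "\n", "# Hypothetical helper: successive-conditional simulator, i.e. a chain started\n", "# from the prior that alternates θ ~ sampler(rng, x) and x ~ Binomial(3, θ)\n", "function successive_conditional(rng, sampler, n)\n", "    θ = rand(rng, Beta(2, 3))\n", "    x = rand(rng, Binomial(3, θ))\n", "    θs, xs = zeros(n), zeros(Int, n)\n", "    for i in 1:n\n", "        θ = sampler(rng, x)\n", "        x = rand(rng, Binomial(3, θ))\n", "        θs[i], xs[i] = θ, x\n", "    end\n", "    return θs, xs\n", "end\n", "\n", "# RNG-aware copies of rand_θ_given and rand_θ_given_buggy from the first cell\n", "sampler_ok(rng, x) = rand(rng, Beta(2 + x, 3 + 3 - x))\n", "sampler_buggy(rng, x) = sampler_ok(rng, min(3, x + 1))\n", "\n", "# Compare the mean of θ under the two simulators (the prior mean is 0.4)\n", "rng = MersenneTwister(1)\n", "θm, _ = marginal_conditional(rng, 5_000)\n", "θs, _ = successive_conditional(rng, sampler_ok, 5_000)\n", "θb, _ = successive_conditional(rng, sampler_buggy, 5_000)\n", "@show mean(θm) mean(θs) mean(θb)"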
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\n" ] }, { "data": { "text/plain": [ "Geweke (Joint Distribution) Test\n", "--------------------------------\n", "Results:\n", " Number of samples: 5000\n", " Parameter dimension: 1\n", " Data dimension: 1\n", " Statistic: [1.0069391903102192, 0.46797362344774]\n", " P-value: [0.3139639974176164, 0.6398034522528347]\n" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using MCMCDebugging\n", "\n", "res = perform(GewekeTest(5_000), BetaBinomial, rand_θ_given; g=g)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Both p-values are reasonably large ($\\geq 0.1$), so there is no evidence against the hypothesis that the sampler is bug-free." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also visualise the result with a Q-Q plot; plotting is supported via Plots.jl." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Quantile error: 0.003602000000000002\n" ] }, { "data": { "text/plain": [ "[Q-Q plot: Bug-free sampler]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using Plots\n", "\n", "plot(res, BetaBinomial(); size=(300, 300), title=\"Bug-free sampler\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's run the Geweke test on the buggy sampler." 
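, "\n", "\n", "Before running it, note what the bug does: `rand_θ_given_buggy` conditions on $\\min(3, x + 1)$ instead of $x$, so for every $x < 3$ it samples from a distribution shifted upwards relative to the exact posterior $\\mathrm{Beta}(2 + x, 6 - x)$, whose mean is $(2 + x)/8$. The cell below is a short sketch using Distributions.jl, with hypothetical helper names `post_mean` and `post_mean_buggy`, that prints both posterior means for each possible $x$; the two agree only at $x = 3$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "using Distributions\n", "\n", "# Hypothetical helpers: posterior mean of θ given x under the exact\n", "# Beta(2 + x, 6 - x) posterior and under the buggy sampler, which\n", "# conditions on min(3, x + 1) instead of x\n", "post_mean(x) = mean(Beta(2 + x, 3 + 3 - x))\n", "post_mean_buggy(x) = post_mean(min(3, x + 1))\n", "\n", "for x in 0:3\n", "    println(\"x = $x: exact $(post_mean(x)), buggy $(post_mean_buggy(x))\")\n", "end"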
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\n", "┌ Warning: Test function `g` is not provided. Statistic is not computed.\n", "└ @ MCMCDebugging /Users/kai/projects/TuringLang/MCMCDebugging.jl/src/geweke.jl:77\n" ] }, { "data": { "text/plain": [ "Geweke (Joint Distribution) Test\n", "--------------------------------\n", "Results:\n", " Number of samples: 5000\n", " Parameter dimension: 1\n", " Data dimension: 1\n", " Statistic: missing\n", " P-value: missing\n", "\n", "Test statistic is missing. Please use `compute_statistic!(res, g)` \n", "if you want compute statistic without rerun the simulation.\n", "\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res_buggy = perform(GewekeTest(5_000), BetaBinomial, rand_θ_given_buggy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Oops, this time I forgot to pass the test function `g`. That is fine if we only want to make the Q-Q plot, but to get the statistics and p-values, let's follow the recommendation and update the result without rerunning the simulation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Geweke (Joint Distribution) Test\n", "--------------------------------\n", "Results:\n", " Number of samples: 5000\n", " Parameter dimension: 1\n", " Data dimension: 1\n", " Statistic: [-41.00888978073274, -24.577724389227683]\n", " P-value: [0.0, 2.186427127224003e-133]\n" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compute_statistic!(res_buggy, g)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The p-values are now essentially zero, providing strong evidence against the hypothesis that the sampler is bug-free." 
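, "\n", "\n", "The reported p-values appear consistent with a two-sided test of each statistic $z$ against a standard normal reference distribution, $p = 2\\,(1 - \\Phi(|z|))$, where $\\Phi$ is the standard normal CDF. This is our reading of the numbers above rather than a statement about MCMCDebugging.jl's internals. The cell below recomputes the p-values from the reported statistics with Distributions.jl, using a hypothetical helper `two_sided_p`. For $|z| \\approx 41$ the tail probability is far below the smallest positive `Float64`, which is why the first p-value for the buggy sampler is reported as exactly `0.0`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "using Distributions\n", "\n", "# Statistics reported by the two Geweke tests above\n", "z_bugfree = [1.0069391903102192, 0.46797362344774]\n", "z_buggy = [-41.00888978073274, -24.577724389227683]\n", "\n", "# Hypothetical helper: two-sided p-value, assuming a standard normal reference\n", "# distribution; ccdf avoids the cancellation in 1 - cdf for large |z|\n", "two_sided_p(z) = 2 * ccdf(Normal(), abs(z))\n", "\n", "@show two_sided_p.(z_bugfree) two_sided_p.(z_buggy)"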
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The visualisation is also very informative: compare the quantile error below with the one for the bug-free sampler." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Quantile error: 0.157802\n" ] }, { "data": { "text/plain": [ "[Q-Q plot: Buggy sampler]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot(res_buggy, BetaBinomial(); size=(300, 300), title=\"Buggy sampler\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also check the maximum mean discrepancy (MMD) of each result." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "┌ Info: MMD\n", "│ mmd_of(res) = 9.005168477094205e-5\n", "│ mmd_of(res_buggy) = 0.06918151052999455\n", "└ @ Main In[7]:1\n" ] } ], "source": [ "@info \"MMD\" mmd_of(res) mmd_of(res_buggy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the bug-free sampler attains a much smaller MMD than the buggy one (about $9.0 \\times 10^{-5}$ versus $0.069$), meaning that the marginal-conditional simulator and the bug-free successive-conditional simulator produce much more similar distributions." ] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.5.0", "language": "julia", "name": "julia-1.5" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.5.0" } }, "nbformat": 4, "nbformat_minor": 4 }