{ "cells": [ { "cell_type": "markdown", "id": "7c1c7652", "metadata": {}, "source": [ "# Truncated and folded distributions\n", "\n", "This tutorial will cover how to work with truncated and folded\n", "distributions in NumPyro.\n", "It is assumed that you're already familiar with the basics of NumPyro.\n", "To get the most out of this tutorial, you'll need some background in probability.\n", "\n", "\n", "### Table of contents\n", "\n", "* [0. Setup](#0)\n", "* [1. What is a truncated distribution?](#1)\n", "* [2. What is a folded distribution?](#2)\n", "* [3. Sampling from truncated and folded distributions](#3)\n", "* [4. Ready-to-use truncated and folded distributions](#4)\n", "* [5. Building your own truncated distributions](#5)\n", " * [5.1 Recap of NumPyro distributions](#5.1)\n", " * [5.2 Right-truncated normal](#5.2)\n", " * [5.3 Left-truncated Poisson](#5.3)\n", "* [6. References and related material](#references)\n", "\n", "\n", "### 0. Setup \n", "To run this notebook, we are going to need the following imports:" ] }, { "cell_type": "code", "execution_count": null, "id": "caed918e", "metadata": {}, "outputs": [], "source": [ "!pip install -q git+https://github.com/pyro-ppl/numpyro.git" ] }, { "cell_type": "code", "execution_count": 2, "id": "04fce45a", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import numpyro\n", "import numpyro.distributions as dist\n", "from jax import lax, random\n", "from jax.scipy.special import ndtr, ndtri\n", "from jax.scipy.stats import poisson, norm\n", "from numpyro.distributions import (\n", " constraints,\n", " Distribution,\n", " FoldedDistribution,\n", " SoftLaplace,\n", " StudentT,\n", " TruncatedDistribution,\n", " TruncatedNormal,\n", ")\n", "from numpyro.distributions.util import promote_shapes\n", "from numpyro.infer import DiscreteHMCGibbs, MCMC, NUTS, Predictive\n", "from scipy.stats import poisson as sp_poisson\n", 
"\n", "numpyro.enable_x64()\n", "RNG = random.PRNGKey(0)\n", "PRIOR_RNG, MCMC_RNG, PRED_RNG = random.split(RNG, 3)\n", "MCMC_KWARGS = dict(\n", " num_warmup=2000,\n", " num_samples=2000,\n", " num_chains=4,\n", " chain_method=\"sequential\",\n", ")" ] }, { "cell_type": "markdown", "id": "00d0ba5d", "metadata": {}, "source": [ "### 1. What is a truncated distribution?\n", "\n", "\n", "The **support** of a probability distribution is the set of values\n", "in the domain with **non-zero probability** (for continuous distributions, non-zero probability density). For example, the\n", "support of the normal distribution is the whole real line (the density gets very small as we move away from the mean but, technically\n", "speaking, it is never exactly zero). The support of the uniform distribution,\n", "as coded in `jax.random.uniform` with the default arguments, is the interval $[0, 1)$, because any\n", "value outside of that interval has zero probability. The support of the Poisson distribution is the set of non-negative integers.\n", "\n", "**Truncating** a distribution makes its support smaller\n", "so that any value outside our desired domain has zero probability. 
In practice, this can be useful\n", "for modelling situations in which certain biases are introduced during data collection.\n", "For example, some physical detectors only get triggered when the signal is above some\n", "minimum threshold, or sometimes the detectors fail if the signal exceeds a certain value.\n", "As a result, the **observed values are constrained to be within a limited range of values**,\n", "even though the true signal does not have the same constraints.\n", "See, for example, section 3.1 of _Information Theory, Inference, and Learning Algorithms_ by David MacKay.\n", "Naively, if $S$ is the support of the original density $p_Y(y)$, then by truncating to a new support\n", "$T\\subset S$ we are effectively defining a new random variable $Z$ for which the density is\n", "\n", "$$\n", "\\begin{align}\n", " p_Z(z) \\propto\n", " \\begin{cases}\n", " p_Y(z) & \\text{if $z$ is in $T$}\\\\\n", " 0 & \\text{if $z$ is outside $T$}\\\\\n", " \\end{cases}\n", "\\end{align}\n", "$$\n", "\n", "The reason for writing a $\\propto$ (proportional to) sign instead of a strict equation is that,\n", "defined in the above way, the resulting function does not in general integrate to $1$ (it integrates to the probability mass of $p_Y$ inside $T$), so it cannot strictly be considered a probability density. To make it into a probability density, **we need to redistribute the truncated mass**\n", "among the part of the distribution that remains. To do this, we simply re-weight every point by the same constant:\n", "\n", "$$\n", "\\begin{align}\n", " p_Z(z) =\n", " \\begin{cases}\n", " \\frac{1}{M}p_Y(z) & \\text{if $z$ is in $T$}\\\\\n", " 0 & \\text{if $z$ is outside $T$}\\\\\n", " \\end{cases}\n", "\\end{align}\n", "$$\n", "\n", "where $M = \\int_T p_Y(y)\\mathrm{d}y$.\n", "\n", "In practice, the truncation is often one-sided. This means that if, for example, the support before truncation is the interval $(a, b)$, then the support after truncation is of the form $(a, c)$ or $(c, b)$, with $a < c < b$. 
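To check the re-weighting formula numerically, the sketch below uses plain `jax` (not NumPyro). It assumes a standard normal $p_Y$ truncated to the illustrative interval $T = (-1, 2)$. It evaluates $p_Z = p_Y / M$ on a grid and checks that it integrates to approximately $1$, whereas $p_Y$ restricted to $T$ only integrates to $M$.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Illustrative truncation interval T = (low, high) for a standard normal p_Y.
low, high = -1.0, 2.0

# M is the probability mass of p_Y inside T.
M = norm.cdf(high) - norm.cdf(low)


def truncated_pdf(z):
    # p_Z(z) = p_Y(z) / M inside T and 0 outside T.
    inside = (z > low) & (z < high)
    return jnp.where(inside, norm.pdf(z) / M, 0.0)


def trapezoid(y, x):
    # Plain trapezoidal rule for the numerical integrals.
    return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))


grid = jnp.linspace(-5.0, 5.0, 20_001)
mask = (grid > low) & (grid < high)

area_restricted = trapezoid(jnp.where(mask, norm.pdf(grid), 0.0), grid)
area_truncated = trapezoid(truncated_pdf(grid), grid)

print(f'M = {float(M):.4f}')
print(f'integral of p_Y over T: {float(area_restricted):.4f}')  # approximately M
print(f'integral of p_Z: {float(area_truncated):.4f}')  # approximately 1
```

The two integrals differ by the factor $M$, which is exactly what the $1/M$ re-weighting compensates for.
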
The figure below illustrates a left-sided truncation at zero of a normal distribution $N(1, 1)$.\n", "\n", "
\n", " \"truncated\"\n", "
\n", "\n", "The original distribution (left panel) is truncated at the vertical dotted line. The truncated mass (orange region) is redistributed in the new support (right panel) so that the total area under the curve remains equal to 1 even after truncation. This method of re-weighting ensures that the density ratio between any two points, $p(a)/p(b)$, remains the same before and after the re-weighting is done (as long as both points are inside the new support, of course).\n", "\n", "**Note**: Truncated data is different from _censored_ data. Censoring also hides values that are outside some desired support but, unlike with truncated data, we know when a value has been censored. A typical example is a household scale that does not report values above 300 pounds: we cannot read the exact weight, but we do know that it exceeds 300 pounds. Censored data will not be covered in this tutorial.\n", "\n", "### 2. What is a folded distribution? \n", "\n", "**Folding** is achieved by taking the absolute value of a random variable, $Z = \\lvert Y \\rvert$. This modifies the support of the original distribution, since negative values now have zero\n", "probability:\n", "\n", "$$\n", "\\begin{align}\n", " p_Z(z) =\n", " \\begin{cases}\n", " p_Y(z) + p_Y(-z) & \\text{if $z\\ge 0$}\\\\\n", " 0 & \\text{if $z\\lt 0$}\\\\\n", " \\end{cases}\n", "\\end{align}\n", "$$\n", "\n", "The figure below illustrates folding a normal distribution $N(1, 1)$.\n", "
\n", " \"folded\"\n", "
\n", "\n", "As you can see, the resulting distribution is different from the truncated case. In particular, the density ratio between points, $p(a)/p(b)$, is in general not the same after folding. For some examples in which folding is relevant, see [references 3 and 4](#references).\n", "\n", "If the original distribution is symmetric around zero, then folding and truncating at zero have the same effect (both produce the density $2p_Y(z)$ for $z \\ge 0$)." ] }, { "cell_type": "markdown", "id": "af028f79", "metadata": {}, "source": [ "### 3. Sampling from truncated and folded distributions \n", "\n", "**Truncated distributions**\n", "\n", "Usually, we already have a sampler for the pre-truncated distribution (e.g. `np.random.normal`).\n", "So, a seemingly simple way of generating samples from the truncated distribution would be to\n", "sample from the original distribution and then discard the samples that are outside the \n", "desired support. For example, if we wanted samples from a normal distribution truncated to the\n", "support $(-\\infty, 1)$, we'd simply do:\n", "\n", "```python\n", "upper = 1\n", "samples = np.random.normal(size=1000)\n", "truncated_samples = samples[samples < upper]\n", "```\n", "\n", "This is called **_rejection sampling_**, but it is **not very efficient**.\n", "If the region we truncated has a sufficiently high probability mass, then we'd be discarding a lot of samples and it might be a while before we accumulate sufficient samples for the truncated distribution. For example, the above snippet would only result in approximately 840 truncated samples even though we initially drew 1000. 
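The figure of roughly 840 is the number of draws times the acceptance probability $P(Y < 1)$, that is, the untruncated CDF evaluated at `upper`. The sketch below rewrites the snippet above with `jax.random` instead of NumPy, purely so that all examples use JAX. It then compares the expected count with an actual rejection step.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.special import ndtr  # CDF of the standard normal

upper = 1.0
num_draws = 1000

# Each draw is kept with probability P(Y < upper), the normal CDF at upper.
acceptance_rate = float(ndtr(jnp.asarray(upper)))
print(f'expected number kept: {num_draws * acceptance_rate:.1f} of {num_draws}')

# Empirical check with an actual rejection step.
key = random.PRNGKey(0)
samples = random.normal(key, (num_draws,))
truncated_samples = samples[samples < upper]
print(f'actually kept: {truncated_samples.shape[0]} of {num_draws}')
```
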
This can easily get a lot worse for other combinations of parameters, for example when most of the probability mass lies outside the desired support.\n", "A **more efficient** approach is to use a method known as [inverse transform sampling](https://en.wikipedia.org/wiki/Inverse_transform_sampling).\n", "In this method, we first sample from a uniform distribution on $(0, 1)$ and then transform those samples with the inverse cumulative distribution function of our truncated distribution.\n", "This method ensures that no samples are wasted in the process, though it does have the slight complication that\n", "**we need to calculate the inverse CDF (ICDF)** of our truncated distribution. This might sound complicated at first but, with a bit of algebra, we can often express the truncated ICDF in terms of the untruncated ICDF, which is already available for many distributions.\n", "\n", "**Folded distributions**\n", "\n", "This case is a lot simpler. Since we already have a sampler for the pre-folded distribution, all we need to do is to take the absolute value of those samples:\n", "\n", "```python\n", "samples = np.random.normal(size=1000)\n", "folded_samples = np.abs(samples)\n", "```" ] }, { "cell_type": "markdown", "id": "b461367f", "metadata": {}, "source": [ "### 4. Ready-to-use truncated and folded distributions \n", "\n", "The later sections in this tutorial will show you how to construct your own truncated and folded distributions, but you don't have to reinvent the wheel. NumPyro has [a bunch of truncated distributions](https://num.pyro.ai/en/stable/distributions.html#truncated-distributions) already implemented.\n", "\n", "Suppose, for example, that you want a normal distribution truncated on the right.\n", "For that purpose, we can use the [TruncatedNormal](https://num.pyro.ai/en/stable/distributions.html#truncatednormal) distribution. The parameters of this distribution are `loc` and `scale`, corresponding to the `loc` and `scale` of the _untruncated_ normal, and `low` and/or `high` corresponding to the truncation points. 
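To connect these parameters with the inverse-transform idea from section 3, here is a minimal hand-written sketch of right-truncated normal sampling. It uses plain `jax`, not NumPyro's implementation, and the helper `sample_right_truncated_normal` is hypothetical. For $z < h$ the truncated CDF is $F_Z(z) = F_Y(z) / F_Y(h)$, so its inverse is $F_Z^{-1}(u) = F_Y^{-1}(u F_Y(h))$. As a result, every uniform draw yields a valid sample.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.special import ndtr, ndtri  # standard normal CDF and inverse CDF
from jax.scipy.stats import norm


def sample_right_truncated_normal(key, loc, scale, high, shape):
    # Hypothetical helper: inverse-transform sampling from N(loc, scale**2)
    # truncated to (-inf, high).
    # F_Y(high): mass of the untruncated normal below the truncation point.
    mass_below_high = ndtr(jnp.asarray((high - loc) / scale))
    # Uniform draws on (0, 1); a tiny minval avoids u == 0, which would map to -inf.
    u = random.uniform(key, shape, minval=1e-12, maxval=1.0)
    # F_Z^{-1}(u) = F_Y^{-1}(u * F_Y(high)), rescaled back to loc and scale.
    return loc + scale * ndtri(u * mass_below_high)


loc, scale, high = 0.0, 1.0, 1.2
samples = sample_right_truncated_normal(random.PRNGKey(1), loc, scale, high, (100_000,))

# Closed-form mean of a right-truncated normal, for comparison.
beta = jnp.asarray((high - loc) / scale)
analytic_mean = loc - scale * norm.pdf(beta) / ndtr(beta)

print(f'largest sample: {float(samples.max()):.4f} (truncation point {high})')
print(f'sample mean {float(samples.mean()):.4f}, analytic mean {float(analytic_mean):.4f}')
```

In a model, you would simply write `TruncatedNormal(loc, scale, high=high)`. The sketch only makes explicit why inverse-transform sampling wastes no draws.
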
Importantly, `low` and `high` are **keyword-only** arguments; only `loc` and `scale` are valid as positional arguments.\n", "This is how you can use this class in a model:" ] }, { "cell_type": "code", "execution_count": 3, "id": "42aee6ce", "metadata": {}, "outputs": [], "source": [ "def truncated_normal_model(num_observations, high, x=None):\n", " loc = numpyro.sample(\"loc\", dist.Normal())\n", " scale = numpyro.sample(\"scale\", dist.LogNormal())\n", " with numpyro.plate(\"observations\", num_observations):\n", " numpyro.sample(\"x\", TruncatedNormal(loc, scale, high=high), obs=x)" ] }, { "cell_type": "markdown", "id": "2420a2a2", "metadata": {}, "source": [ "Let's now check that we can use this model in a typical MCMC workflow." ] }, { "cell_type": "markdown", "id": "cad8a521", "metadata": {}, "source": [ "**Prior simulation**" ] }, { "cell_type": "code", "execution_count": 4, "id": "3f809001", "metadata": {}, "outputs": [], "source": [ "high = 1.2\n", "num_observations = 250\n", "num_prior_samples = 100\n", "\n", "prior = Predictive(truncated_normal_model, num_samples=num_prior_samples)\n", "prior_samples = prior(PRIOR_RNG, num_observations, high)" ] }, { "cell_type": "markdown", "id": "5c5fa225", "metadata": {}, "source": [ "**Inference**\n", "\n", "To test our model, we run MCMC against some synthetic data.\n", "The synthetic data can be any sample from the prior simulation." 
] }, { "cell_type": "code", "execution_count": 5, "id": "e763e254", "metadata": {}, "outputs": [], "source": [ "# -- select an arbitrary prior sample as true data\n", "true_idx = 0\n", "true_loc = prior_samples[\"loc\"][true_idx]\n", "true_scale = prior_samples[\"scale\"][true_idx]\n", "true_x = prior_samples[\"x\"][true_idx]" ] }, { "cell_type": "code", "execution_count": 6, "id": "2c28a0f4", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAEGCAYAAAB8Ys7jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAT8UlEQVR4nO3df3DU9Z3H8de7iAVEUUkETpNGWi+txRrbYCv9gdbYUixSQQ/sharYhjvJnNyBLaCCCqe5ESo6KDUojSe2wQIeoIAFTmA8bA9oU43S1F6GFhx+iYpSsIJ93x9Z5jgazLL7+bD5Ls/HDJPsN7uvfe8485qPn/3ud83dBQBIno/kegAAQGYocABIKAocABKKAgeAhKLAASChTjqeT1ZQUOAlJSXH8ykBIPE2btz4hrsXHnn8uBZ4SUmJNmzYcDyfEgBybsuWLZKkoqKijB5vZn9o7fhxLXAAOBGNGDFCkrR69eqguRQ4AER2++23R8mlwAEgsoqKiii5nIUCAJE1Nzerubk5eC4rcACIbOTIkZLYAweAxLnrrrui5FLgABBZ//79o+SyBw4AkTU1NampqSl4LitwAIhs1KhRktgDB3CYkvHPZvzYzTVXBpwEH+aee+6JkkuBA0Bk/fr1i5LLHjgARNbY2KjGxsbguazAASCy6upqSeyBA0Di3HfffVFyKXAAiKxv375RctkDB4DIGhoa1NDQEDyXFTgARDZmzBhJ7IEDQOLMmDEjSi4FDgCRlZWVRcllDxwAIlu/fr3Wr18fPJcVOABEduutt0piDxwAEmfmzJlRcilwAIisT58+UXLZAweAyNatW6d169YFz21zBW5mnSStlfTR1P3nu/tkMztXUr2k7pI2Shrh7u8HnxAAEm7ixImScrMH/mdJX3X3vWbWUdILZrZM0r9Iut/d683sR5JukjQr6HQAkAceeeSRKLltbqF4i72pmx1T/1zSVyXNTx1/XNK3YgwIAElXWlqq0tLS4Llp7YGbWQcza5C0U9IKSf8j6W13P5i6y1ZJZwefDgDywJo1a7RmzZrguWmdheLuH0gqM7PTJT0t6ZPpPoGZVUmqkqTi4uIMRgSAZJs8ebKkHJ8H7u5vm9nzki6RdLqZnZRahZ8j6fWjPKZWUq0klZeXe5bzAkDizJkzJ0pum1soZlaYWnnLzDpLukLSJknPS7omdbfrJS2KMiEAJFzv3r3Vu3fv4LnprMB7SXrczDqopfCfcvdnzOxVSfVmNlXSryU9Fnw6AMgDK1eulCRVVFQEzW2zwN39JUkXtXK8WdLFQacBgDw0depUSTkocABAdp544okouRQ4cIIqGf9sxo/dXHNlwEnyX1FRUZRcroUCAJEtX75cy5cvD57LChwAIqupqZEkDRgwIGguBQ4AkdXX10fJpcABILKePXtGyWUPHAAiW7JkiZYsWRI8lxU4AEQ2ffp0SdKgQYOC5lLgABDZ/Pnz275TBihwAIisoKAgSi574AAQ2cKFC7Vw4cLguazAASCyBx98UJI0ZMiQoLkUOABEtmhRnKttU+AA
EFm3bt2i5LIHDgCRzZs3T/PmzQueywocACKbNWuWJGnYsGFBcylwAIhs6dKlUXIpcACIrEuXLlFy2QMHgMjmzp2ruXPnBs9lBQ4AkT366KOSpMrKyqC5FDgARLZixYoouRQ4AETWsWPHKLnsgQNAZHV1daqrqwue22aBm1mRmT1vZq+a2Stmdkvq+J1m9rqZNaT+DQw+HQDkgVgFns4WykFJY939V2Z2qqSNZnZoQ+d+d58WfCoAyCOrV6+Okttmgbv7NknbUr+/a2abJJ0dZRoAQNqOaQ/czEokXSTpl6lD1Wb2kpnNMbMzjvKYKjPbYGYbdu3ald20AJBAs2fP1uzZs4Pnpl3gZtZV0gJJY9z9HUmzJH1cUplaVujTW3ucu9e6e7m7lxcWFmY/MQAkTE4vZmVmHdVS3k+6+0JJcvcdh/19tqRngk8HAHlg5cqVUXLTOQvFJD0maZO7//Cw470Ou9vVkhrDjwcAOJp0VuBflDRC0stm1pA6NlHSdWZWJsklbZY0KsJ8AJB4Dz/8sCTp5ptvDpqbzlkoL0iyVv4U5/qIAJBnlixZIikHBQ4AyM6yZcui5PJRegBIKAocACJ74IEH9MADDwTPpcABILJVq1Zp1apVwXPZAweAyBYvXhwllxU4ACQUBQ4AkU2bNk3TpoW/cCtbKAAQ2YsvvhgllwIHgMgWLFgQJZctFABIKAocACKrqalRTU1N8Fy2UAAgsoaGhii5FDgARFZfXx8lly0UAEgoChwAIpsyZYqmTJkSPJctFACIrKmpKUouBQ4Akc2dOzdKLlsoAJBQFDgARDZp0iRNmjQpeC5bKICkkvHPZvX4zTVXBpoE+WjLli1RcilwAIjsxz/+cZRctlAAIKHaLHAzKzKz583sVTN7xcxuSR0/08xWmNlrqZ9nxB8XAJJnwoQJmjBhQvDcdFbgByWNdffzJX1B0mgzO1/SeEmr3P08SatStwEAR9i9e7d2794dPLfNPXB33yZpW+r3d81sk6SzJQ2WdGnqbo9LWi3pB8EnBICEq62tjZJ7TG9imlmJpIsk/VJSj1S5S9J2ST2O8pgqSVWSVFxcnPGgAPJHNmf9cMbP/0n7TUwz6yppgaQx7v7O4X9zd5fkrT3O3WvdvdzdywsLC7MaFgCSaNy4cRo3blzw3LQK3Mw6qqW8n3T3hanDO8ysV+rvvSTtDD4dAOSB/fv3a//+/cFz29xCMTOT9JikTe7+w8P+tFjS9ZJqUj8XBZ8OAPLAQw89FCU3nT3wL0oaIellM2tIHZuoluJ+ysxukvQHSX8XZUIAQKvSOQvlBUl2lD9fHnYcAMg/Y8aMkSTNmDEjaC6fxASAhOJaKEAAnBaHDxN65X0IK3AASCgKHAAiGz16tEaPHh08ly0UAIisc+fOUXIpcACIbNq0aVFy2UIBgIRiBQ7kWLZf54b2r6qqSlL4qxJS4AAQWffu3aPkUuAAENm9994bJZc9cABIKAocACK78cYbdeONNwbPZQsFACIrKiqKkkuBA0Bkd999d5RctlAAIKEocACIrLKyUpWVlcFz2UIBgMhKS0uj5FLgABDZHXfcESWXLRQASCgKHAAiGz58uIYPHx48ly0UAIisrKwsSi4FDgCRjR8/PkouWygAkFBtFriZzTGznWbWeNixO83sdTNrSP0bGHdMAEiuoUOHaujQocFz09lCqZM0U9K/H3H8fneP8z1BAJBHLrnkkii5bRa4u681s5Iozw4AJ4Bx48ZFyc3mTcxqM/uOpA2Sxrr7W63dycyqJFVJUnFxcRZPBwC5k+1X322uuTLQJP8n0zcxZ0n6uKQySdskTT/aHd291t3L3b28sLAww6cDgOTaueBuXXXVVcFzM1qBu/uOQ7+b2WxJzwSbCADyTKePXajLL/908NyMCtzMern7ttTNqyU1ftj9AeBEdlr5YN1yS/gtlDYL3Mx+KulSSQVmtlXSZEmXmlmZJJe0WdKo4JMBAD5UOmehXNfK4cci
zAIAeWnHU5P1jd/M1LJly4Lm8lF6AMcs2zMyTjRdPnGxBg3qEzyXAgeAyE797JW6+eb2cxohACDHKHAAiGxH/W2qqKgInssWCgBE1uWTX9awoZ8JnkuBA0Bkp5YN0Pe+l4PzwAGgPcnmDJgY1yPJJfbAASCy7T8Zr0svvTR4LitwAIis6wUVuuHaC4PnUuAAEFnXCyp0ww2cBw4AieMfHNSBAweC51LgABDZjnm364orrgieyxYKAETW9cKv67vDyoLnUuAAEFnXT1+mykr2wAEgcf5y4D3t27cveC4FDgCR7fzZnRo4cGDwXLZQACCyUy8aqH/89meD51LgABDZKZ/6ioYNYw8cABLnL3/+k/bs2RM8lxU42hUuVISYcvVVcDsXTNHgXz+s1atXB82lwAEgstM+d5X+acTngudS4AAQWZfSfhoyJAd74GY2x8x2mlnjYcfONLMVZvZa6ucZwScDgDzxwb49euONN4LnpvMmZp2kAUccGy9plbufJ2lV6jYAoBW7/uNeXXPNNcFz2yxwd18r6c0jDg+W9Hjq98clfSvsWACQP067+GqNHTs2eG6me+A93H1b6vftknoc7Y5mViWpSpKKi4szfDokSa7e6c/V8wJt6fKJz2vQoHZ4Hri7uyT/kL/Xunu5u5cXFhZm+3QAkDgf7H1L27dvD56baYHvMLNekpT6uTPcSACQX3Yt/jcNHz48eG6mBb5Y0vWp36+XtCjMOACQf7p94VqNHx/+XI90TiP8qaQXJZWa2VYzu0lSjaQrzOw1SRWp2wCAVnTu/TkNGHDkyXzZa/NNTHe/7ih/ujzwLACQlw6+s0tbtmxRUVFR0Fw+iZmnuKYI0H688cx0jXi1jmuhAEDSdOs3XLd/9/PBc7mcLABE1rmkTBUVFcFzKXAAiOzA29vV3NwcPJcCB4DIdi+doZEjRwbPZQ8cACI7/Ut/r7tGXRI8lxU4AETWqfgC9e/fP3guBQ4AkR3YvVVNTU3BcylwAIhs93MzNWrUqOC57IEDQGSnf+V63XNzv+C5rMABILJO53xK/fpR4ACQOO/v2qzGxsa273iMKHAAiOzNFT9SdXV18Fz2wAEgsjMuG6n7qr8UPJcVOABE9tFef6u+ffsGz6XAASCy93c0q6GhIXguBQ4Akb25qlZjxowJnsseOABEdublVZpxy5eD57ICB4DITu7RW2VlZcFzWYHjr2TzdWwA/tqft/1O69efFfyNTFbgABDZW8/P0a233ho8lxU4AER25hX/oJn/HP5yslkVuJltlvSupA8kHXT38hBDAUA+ObmwRH369AmeG2IL5TJ3L6O8AaB1723dpHXr1gXPZQ8cACJ7e+3jmjhxYvDcbAvcJf3czDaaWVVrdzCzKjPbYGYbdu3aleXTAUDydP96tR555JHgudkW+Jfc/bOSviFptJl95cg7uHutu5e7e3lhYWGWTwcAydOx+zkqLS0NnptVgbv766mfOyU9LeniEEMBQD55748va82aNcFzMy5wMzvFzE499Lukr0kKf8VyAEi4t194UpMnTw6em81phD0kPW1mh3J+4u7Lg0wFAHmk+8AxmvP9y4LnZlzg7t4s6cKAswBAXup4ek/17t07eC6nEQJAZPs3N2jlypXBc/kofTvGRaWA/LBnXb2mbl+hioqKoLmswAEgsoJvjtUTTzwRPJcCB4DITjqtUEVFRcFzKXAAiGx/80YtXx7+JD0KHAAi2/OLn6mmpiZ4LgUOAJEVXvUD1dfXB8/lLJTIOJMEQIeuZ6hnz57Bc1mBA0Bk+37/Sy1ZsiR4LgUOAJG9899Pa/r06cFzKXAAiKzwWxM0f/784LkUOABE1qFLNxUUFATPpcABILJ9Teu0cOHC4LknxFko2Z4JsrnmykCTADgRvbNxsR586780ZMiQoLmswAEgsrOG3qFFixYFz6XAASCyj3z0FHXr1i18bvBEAMD/86dNazVv3rzguRQ4AET27q+XatasWcFzKXAA
iOysa+/U0qVLg+dS4AAQ2Uc6dlKXLl2C5ybmNMJcXhSKC1IByMbeV57X3LlvqbKyMmguK3AAiGzvb57To48+GjyXAgeAyHoMm6oVK1YEz82qwM1sgJk1mdnvzWx8qKEAIJ9Yh5PUsWPH4LkZF7iZdZD0kKRvSDpf0nVmdn6owQAgX+x9eaXq6uqC52azAr9Y0u/dvdnd35dUL2lwmLEAIH/EKvBszkI5W9KWw25vlfT5I+9kZlWSqlI395pZUxbP2R4USHoj10MExOtp33g97Vvar2fNFsnMMn2ej7V2MPpphO5eK6k29vMcL2a2wd3Lcz1HKLye9o3X077l+vVks4XyuqSiw26fkzoGADgOsinw9ZLOM7NzzexkScMlLQ4zFgCgLRlvobj7QTOrlvScpA6S5rj7K8Ema7/yZjsohdfTvvF62recvh5z91w+PwAgQ3wSEwASigIHgISiwDNkZmPNzM2sINezZMPMppjZS2bWYGY/N7O/yfVM2TKz+8zst6nX9bSZnZ7rmbJhZtea2Stm9hczS+wpePl06Q0zm2NmO82sMZdzUOAZMLMiSV+T9MdczxLAfe7+GXcvk/SMpEk5nieEFZL6uPtnJP1O0oQcz5OtRklDJK3N9SCZysNLb9RJGpDrISjwzNwv6fuSEv8OsLu/c9jNU5Qfr+nn7n4wdfMXavmMQmK5+yZ3T/onmPPq0hvuvlbSm7meIzFf6NBemNlgSa+7+2+y+Fhsu2Jm/yrpO5L2SLosx+OENlJS+G+TxbFK69IbODYUeCvMbKWknq386TZJE9WyfZIYH/Z63H2Ru98m6TYzmyCpWtLk4zpgBtp6Tan73CbpoKQnj+dsmUjn9QBHosBb4e4VrR03swsknSvp0Or7HEm/MrOL3X37cRzxmBzt9bTiSUlLlYACb+s1mdkNkr4p6XJPwIcdjuG/UVJx6Y0IKPBj4O4vSzrr0G0z2yyp3N0Te3U1MzvP3V9L3Rws6be5nCcEMxuglvco+rv7vlzPA0mHXXpDLcU9XNK3cztS8vEmJmrMrNHMXlLL1tAtuR4ogJmSTpW0InV65I9yPVA2zOxqM9sq6RJJz5rZc7me6Vil3lQ+dOmNTZKeSvKlN8zsp5JelFRqZlvN7KaczJGA/7sEALSCFTgAJBQFDgAJRYEDQEJR4ACQUBQ4ACQUBQ4ACUWBA0BCUeA4oZlZ39R1wzuZ2Smp6273yfVcQDr4IA9OeGY2VVInSZ0lbXX3e3M8EpAWChwnPDM7WS3X6nhPUj93/yDHIwFpYQsFkLpL6qqW66d0yvEsQNpYgeOEZ2aL1fINMedK6uXu1TkeCUgLl5PFCc3MviPpgLv/JPW9jevM7Kvu/p+5ng1oCytwAEgo9sABIKEocABIKAocABKKAgeAhKLAASChKHAASCgKHAAS6n8B2kQJnfPwppQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.hist(true_x.copy(), bins=20)\n", "plt.axvline(high, linestyle=\":\", color=\"k\")\n", "plt.xlabel(\"x\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 7, "id": "cb7a51be", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1909.24it/s, 1 steps of size 5.65e-01. acc. prob=0.93]\n", "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 10214.14it/s, 3 steps of size 5.16e-01. acc. prob=0.95]\n", "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 15102.95it/s, 1 steps of size 6.42e-01. acc. prob=0.90]\n", "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 16522.03it/s, 3 steps of size 6.39e-01. acc. 
prob=0.90]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " loc -0.58 0.15 -0.59 -0.82 -0.35 2883.69 1.00\n", " scale 1.49 0.11 1.48 1.32 1.66 3037.78 1.00\n", "\n", "Number of divergences: 0\n", "True loc : -0.56\n", "True scale: 1.4\n" ] } ], "source": [ "# --- Run MCMC and check estimates and diagnostics\n", "mcmc = MCMC(NUTS(truncated_normal_model), **MCMC_KWARGS)\n", "mcmc.run(MCMC_RNG, num_observations, high, true_x)\n", "mcmc.print_summary()\n", "\n", "# --- Compare to ground truth\n", "print(f\"True loc : {true_loc:3.2}\")\n", "print(f\"True scale: {true_scale:3.2}\")" ] }, { "cell_type": "markdown", "id": "cac57188", "metadata": {}, "source": [ "**Removing the truncation**" ] }, { "cell_type": "markdown", "id": "796870b7", "metadata": {}, "source": [ "Once we have inferred the parameters of our model, a common task is to understand what the data would look like _without_ the truncation. In this example, this is easily done by simply \"pushing\" the value of `high` to infinity." ] }, { "cell_type": "code", "execution_count": 8, "id": "5d5c9763", "metadata": {}, "outputs": [], "source": [ "pred = Predictive(truncated_normal_model, posterior_samples=mcmc.get_samples())\n", "pred_samples = pred(PRED_RNG, num_observations, high=float(\"inf\"))" ] }, { "cell_type": "markdown", "id": "829f9f12", "metadata": {}, "source": [ "Let's finally plot these samples and compare them to the original, observed data." ] }, { "cell_type": "code", "execution_count": 9, "id": "5cf4724e", "metadata": {}, "outputs": [], "source": [ "# thin the samples to not saturate matplotlib\n", "samples_thinned = pred_samples[\"x\"].ravel()[::1000]" ] }, { "cell_type": "code", "execution_count": 10, "id": "286a3c03", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f, axes = plt.subplots(1, 2, figsize=(15, 5), sharex=True)\n", "\n", "axes[0].hist(\n", " samples_thinned.copy(), label=\"Untruncated posterior\", bins=20, density=True\n", ")\n", "axes[0].set_title(\"Untruncated posterior\")\n", "\n", "vals, bins, _ = axes[1].hist(\n", " samples_thinned[samples_thinned < high].copy(),\n", " label=\"Untruncated posterior below truncation point\",\n", " bins=10,\n", " density=True,\n", ")\n", "axes[1].hist(\n", " true_x.copy(), bins=bins, label=\"Observed, truncated data\", density=True, alpha=0.5\n", ")\n", "axes[1].set_title(\"Comparison to observed data\")\n", "\n", "for ax in axes:\n", " ax.axvline(high, linestyle=\":\", color=\"k\", label=\"Truncation point\")\n", " ax.legend()\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "17ee66e7", "metadata": {}, "source": [ "The plot on the left shows data simulated from the posterior distribution with the truncation removed, so we can see what the data would look like if it were not truncated. As a sanity check, we discard the simulated samples that are above the truncation point, make a histogram of the remaining ones, and compare it to a histogram of the observed data (right plot)." ] }, { "cell_type": "markdown", "id": "80b5e7c0", "metadata": {}, "source": [ "**The TruncatedDistribution class**\n", "\n", "NumPyro's implementation of [TruncatedNormal](https://num.pyro.ai/en/stable/distributions.html#truncatednormal) uses a class called [TruncatedDistribution](https://num.pyro.ai/en/stable/distributions.html#truncateddistribution), which abstracts away the logic for `sample` and `log_prob` that\n", "we will discuss in the next sections. At the moment, though, this logic only works for continuous, symmetric distributions with _real_ support.\n", "\n", "We can use this class to quickly construct other truncated distributions. 
For example, if we need a truncated [SoftLaplace](https://num.pyro.ai/en/stable/distributions.html#softlaplace) we can use the following pattern:" ] }, { "cell_type": "code", "execution_count": 11, "id": "725d6d5d", "metadata": {}, "outputs": [], "source": [ "def TruncatedSoftLaplace(\n", " loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None\n", "):\n", " return TruncatedDistribution(\n", " base_dist=SoftLaplace(loc, scale),\n", " low=low,\n", " high=high,\n", " validate_args=validate_args,\n", " )" ] }, { "cell_type": "code", "execution_count": 12, "id": "f217d43e", "metadata": {}, "outputs": [], "source": [ "def truncated_soft_laplace_model(num_observations, high, x=None):\n", " loc = numpyro.sample(\"loc\", dist.Normal())\n", " scale = numpyro.sample(\"scale\", dist.LogNormal())\n", " with numpyro.plate(\"obs\", num_observations):\n", " numpyro.sample(\"x\", TruncatedSoftLaplace(loc, scale, high=high), obs=x)" ] }, { "cell_type": "markdown", "id": "74e05391", "metadata": {}, "source": [ "And, as before, we check that we can use this model in the steps of a typical workflow:" ] }, { "cell_type": "code", "execution_count": 13, "id": "01d3c464", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1745.70it/s, 1 steps of size 6.78e-01. acc. prob=0.93]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 9294.56it/s, 1 steps of size 7.02e-01. acc. prob=0.93]\n", "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 10412.30it/s, 1 steps of size 7.20e-01. acc. 
prob=0.92]\n", "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 10583.85it/s, 3 steps of size 7.01e-01. acc. prob=0.93]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " loc -0.37 0.17 -0.38 -0.65 -0.10 4034.96 1.00\n", " scale 1.46 0.12 1.45 1.27 1.65 3618.77 1.00\n", "\n", "Number of divergences: 0\n", "True loc : -0.56\n", "True scale: 1.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "high = 2.3\n", "num_observations = 200\n", "num_prior_samples = 100\n", "\n", "prior = Predictive(truncated_soft_laplace_model, num_samples=num_prior_samples)\n", "prior_samples = prior(PRIOR_RNG, num_observations, high)\n", "\n", "true_idx = 0\n", "true_x = prior_samples[\"x\"][true_idx]\n", "true_loc = prior_samples[\"loc\"][true_idx]\n", "true_scale = prior_samples[\"scale\"][true_idx]\n", "\n", "mcmc = MCMC(\n", " NUTS(truncated_soft_laplace_model),\n", " **MCMC_KWARGS,\n", ")\n", "\n", "mcmc.run(\n", " MCMC_RNG,\n", " num_observations,\n", " high,\n", " true_x,\n", ")\n", "\n", "mcmc.print_summary()\n", "\n", "print(f\"True loc : {true_loc:3.2}\")\n", "print(f\"True scale: {true_scale:3.2}\")" ] }, { "cell_type": "markdown", "id": "8e22e5a7", "metadata": {}, "source": [ "**Important**\n", "\n", "The `sample` method of the [TruncatedDistribution](https://num.pyro.ai/en/stable/distributions.html#truncateddistribution) class relies on inverse-transform sampling.\n", "This has the implicit requirement that the base distribution should have an `icdf` method already available.\n", "If this is not the case, we will not be able to call the `sample` method on any instances of our distribution, nor use it with the `Predictive` class.\n", "However, the `log_prob` method only depends on the `cdf` method (which is more frequently available than the `icdf`). 
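To see why the `cdf` is enough, recall from section 1 that truncating to an interval $(l, h)$ divides the base density by $M = F_Y(h) - F_Y(l)$. Inside the interval, this gives `log p_Z(z) = log p_Y(z) - log M`. The following sketch uses plain `jax`, and the helper `truncated_log_prob` is hypothetical, not part of NumPyro's API. It applies this formula to a standard logistic base distribution whose log-density and CDF are written out by hand. No inverse CDF is ever needed.

```python
import jax
import jax.numpy as jnp


def logistic_log_pdf(x):
    # Log-density of the standard logistic: log sigmoid(x) + log sigmoid(-x).
    return jax.nn.log_sigmoid(x) + jax.nn.log_sigmoid(-x)


def logistic_cdf(x):
    # CDF of the standard logistic distribution.
    return jax.nn.sigmoid(x)


def truncated_log_prob(base_log_pdf, base_cdf, value, low, high):
    # Hypothetical helper: log-density of a base distribution truncated to (low, high).
    # The normalizing constant M only requires the base CDF, not its inverse.
    log_mass = jnp.log(base_cdf(high) - base_cdf(low))
    inside = (value > low) & (value < high)
    return jnp.where(inside, base_log_pdf(value) - log_mass, -jnp.inf)


low, high = -1.0, 2.0
grid = jnp.linspace(low, high, 30_001)
density = jnp.exp(truncated_log_prob(logistic_log_pdf, logistic_cdf, grid, low, high))

# Trapezoidal rule over the support: should be close to 1.
area = jnp.sum(0.5 * (density[1:] + density[:-1]) * jnp.diff(grid))
print(f'integral of the truncated density: {float(area):.4f}')

# Values outside the support get a log-density of minus infinity.
print(truncated_log_prob(logistic_log_pdf, logistic_cdf, jnp.array(3.0), low, high))
```

Because only `base_cdf` appears in the normalization, this kind of `log_prob` is available even when the base distribution has no `icdf`.
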
If the `log_prob` method is available, then we _can_ use our distribution as a prior or likelihood in a model." ] }, { "cell_type": "markdown", "id": "91121a38", "metadata": {}, "source": [ "**The FoldedDistribution class**\n", "\n", "Similar to truncated distributions, NumPyro has the [FoldedDistribution](https://num.pyro.ai/en/stable/distributions.html#foldeddistribution) class to help you quickly construct folded distributions. Popular examples of folded distributions are the so-called \"half-normal\", \"half-Student-t\" or \"half-Cauchy\". As the name suggests, these distributions only have support on the positive _half_ of the real line. Implicit in the name of these \"half\" distributions is that they are centered at zero before folding. But, of course, you can fold a distribution even if it is not centered at zero. For instance, this is how you would define a folded Student's t distribution." ] }, { "cell_type": "code", "execution_count": 14, "id": "e7ebd8c5", "metadata": {}, "outputs": [], "source": [ "def FoldedStudentT(df, loc=0.0, scale=1.0):\n", " return FoldedDistribution(StudentT(df, loc=loc, scale=scale))" ] }, { "cell_type": "code", "execution_count": 15, "id": "1857b04c", "metadata": {}, "outputs": [], "source": [ "def folded_student_model(num_observations, x=None):\n", " df = numpyro.sample(\"df\", dist.Gamma(6, 2))\n", " loc = numpyro.sample(\"loc\", dist.Normal())\n", " scale = numpyro.sample(\"scale\", dist.LogNormal())\n", " with numpyro.plate(\"obs\", num_observations):\n", " numpyro.sample(\"x\", FoldedStudentT(df, loc, scale), obs=x)" ] }, { "cell_type": "markdown", "id": "fbc2e678", "metadata": {}, "source": [ "And we check that we can use our distribution in a typical workflow:" ] }, { "cell_type": "code", "execution_count": 16, "id": "fd9f0869", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1343.54it/s, 7 steps of size 3.51e-01. acc. prob=0.75]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 3644.99it/s, 7 steps of size 3.56e-01. acc. prob=0.73]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 3137.13it/s, 7 steps of size 2.62e-01. acc. prob=0.91]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 3028.93it/s, 7 steps of size 1.85e-01. acc. prob=0.96]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " df 3.12 0.52 3.07 2.30 3.97 2057.60 1.00\n", " loc -0.02 0.88 -0.03 -1.28 1.34 925.84 1.01\n", " scale 2.23 0.21 2.25 1.89 2.57 1677.38 1.00\n", "\n", "Number of divergences: 33\n", "True df : 3.01\n", "True loc : 0.37\n", "True scale: 2.41\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# --- prior sampling\n", "num_observations = 500\n", "num_prior_samples = 100\n", "prior = Predictive(folded_student_model, num_samples=num_prior_samples)\n", "prior_samples = prior(PRIOR_RNG, num_observations)\n", "\n", "\n", "# --- choose any prior sample as the ground truth\n", "true_idx = 0\n", "true_df = prior_samples[\"df\"][true_idx]\n", "true_loc = prior_samples[\"loc\"][true_idx]\n", "true_scale = prior_samples[\"scale\"][true_idx]\n", "true_x = prior_samples[\"x\"][true_idx]\n", "\n", "# --- do inference with MCMC\n", "mcmc = MCMC(\n", " NUTS(folded_student_model),\n", " **MCMC_KWARGS,\n", ")\n", "mcmc.run(MCMC_RNG, num_observations, true_x)\n", "\n", "# --- Check 
diagostics\n", "mcmc.print_summary()\n", "\n", "# --- Compare to ground truth:\n", "print(f\"True df : {true_df:3.2f}\")\n", "print(f\"True loc : {true_loc:3.2f}\")\n", "print(f\"True scale: {true_scale:3.2f}\")" ] }, { "cell_type": "markdown", "id": "faf9b651", "metadata": {}, "source": [ "### 5. Building your own truncated distribution \n", "\n", "If the\n", "[TruncatedDistribution](https://num.pyro.ai/en/stable/distributions.html#truncateddistribution) and\n", "[FoldedDistribution](https://num.pyro.ai/en/stable/distributions.html#foldeddistribution)\n", "classes are not sufficient to solve your problem,\n", "you might want to look into writing your own truncated distribution from the ground up.\n", "This can be a tedious process, so this section will give you some guidance and examples to help you with it.\n", "\n", "\n", "#### 5.1 Recap of NumPyro distributions \n", "\n", "A NumPyro distribution should subclass [Distribution](https://num.pyro.ai/en/stable/distributions.html#distribution) and implement a few basic ingredients:\n", "\n", "\n", "**Class attributes**\n", "\n", "The class attributes serve a few different purposes. Here we will mainly care about two:\n", "1. `arg_constraints`: Impose some requirements on the parameters of the distribution. Errors are raised at instantiation time if the parameters passed do not satisfy the constraints.\n", "2. `support`: It is used in some inference algorithms like MCMC and SVI with auto-guides, where we need to perform the algorithm in the unconstrained space. 
Knowing the support, we can automatically reparametrize things under the hood.\n", "\n", "We'll explain other class attributes as we go.\n", "\n", "**The** `__init__` **method**\n", "\n", "This is where we define the parameters of the distribution.\n", "We also use `jnp.shape`, `lax.broadcast_shapes` and NumPyro's `promote_shapes` utility to bring the parameters to shapes that are valid for broadcasting.\n", "Calling the `__init__` method of the parent class is also required, because that is where our parameters are validated.\n", "\n", "**The** `log_prob` **method**\n", "\n", "Implementing the `log_prob` method ensures that we can do inference. As the name suggests, this method returns the logarithm of the density evaluated at the argument.\n", "\n", "**The** `sample` **method**\n", "\n", "This method is used for drawing independent samples from our distribution. It is particularly useful for doing prior and posterior predictive checks. Note, in particular, that this method is not needed if you only need to use your distribution as a prior in a model; the `log_prob` method will suffice.\n", "\n", "\n", "The placeholder code for any of our implementations can be written as\n", "```python\n", "class MyDistribution(Distribution):\n", "    # class attributes\n", "    arg_constraints = {}\n", "    support = None\n", "\n", "    def __init__(self):\n", "        pass\n", "\n", "    def log_prob(self, value):\n", "        pass\n", "\n", "    def sample(self, key, sample_shape=()):\n", "        pass\n", "```\n", "\n", "\n", "#### 5.2 Example: Right-truncated normal \n", "\n", "We are going to modify a normal distribution so that its new support is\n", "of the form `(-inf, high)`, with `high` a real number. This could be done with the `TruncatedNormal` distribution but, for the sake of illustration, we are not going to rely on it.\n", "We'll call our distribution `RightTruncatedNormal`. 
Let's write the skeleton code and then proceed to fill in the blanks.\n", "\n", "```python\n", "class RightTruncatedNormal(Distribution):\n", "    # class attributes\n", "    def __init__(self):\n", "        pass\n", "\n", "    def log_prob(self, value):\n", "        pass\n", "\n", "    def sample(self, key, sample_shape=()):\n", "        pass\n", "```\n", "\n", "\n", "**Class attributes**\n", "\n", "Remember that a non-truncated normal distribution is specified in NumPyro by two parameters, `loc` and `scale`,\n", "which correspond to the mean and standard deviation.\n", "Looking at the [source code](https://github.com/pyro-ppl/numpyro/blob/0664c2d2dd1eb5f41ea6a0bcef91e5fa2a417ce5/numpyro/distributions/continuous.py#L1337) for the `Normal` distribution, we see the following lines:\n", "\n", "```python\n", "arg_constraints = {\"loc\": constraints.real, \"scale\": constraints.positive}\n", "support = constraints.real\n", "reparametrized_params = [\"loc\", \"scale\"]\n", "```\n", "\n", "The `reparametrized_params` attribute is used by variational inference algorithms when constructing gradient estimators. The parameters of many common distributions with continuous support (e.g. the Normal distribution) are reparameterizable, while the parameters of discrete distributions are not. Note that `reparametrized_params` is irrelevant for MCMC algorithms like HMC. See [SVI Part III](https://pyro.ai/examples/svi_part_iii.html#Tricky-Case:-Non-reparameterizable-Random-Variables) for more details.\n", "\n", "We must adapt these attributes to our case by including the `\"high\"` parameter, but there are two issues we need to deal with:\n", "\n", "1. `constraints.real` is a bit too restrictive. We'd like `jnp.inf` to be a valid value for `high` (equivalent to no truncation), but `constraints.real` rejects infinite values. We deal with this situation by defining our own constraint. 
The source code for `constraints.real` is easy to imitate:\n", "\n", "```python\n", "class _RightExtendedReal(constraints.Constraint):\n", "    \"\"\"\n", "    Any number in the interval (-inf, inf].\n", "    \"\"\"\n", "    def __call__(self, x):\n", "        return (x == x) & (x != float(\"-inf\"))  # excludes NaN and -inf\n", "\n", "    def feasible_like(self, prototype):\n", "        return jnp.zeros_like(prototype)\n", "\n", "right_extended_real = _RightExtendedReal()\n", "```\n", "\n", "2. `support` can no longer be a class attribute, because it depends on the value of `high`. Instead, we implement it as a dependent property.\n", "\n", "Our distribution then looks as follows:\n", "```python\n", "class RightTruncatedNormal(Distribution):\n", "    arg_constraints = {\n", "        \"loc\": constraints.real,\n", "        \"scale\": constraints.positive,\n", "        \"high\": right_extended_real,\n", "    }\n", "    reparametrized_params = [\"loc\", \"scale\", \"high\"]\n", "\n", "    # ...\n", "\n", "    @constraints.dependent_property\n", "    def support(self):\n", "        return constraints.less_than(self.high)\n", "```\n", "\n", "**The** `__init__` **method**\n", "\n", "Once again we take inspiration from the [source code](https://github.com/pyro-ppl/numpyro/blob/0664c2d2dd1eb5f41ea6a0bcef91e5fa2a417ce5/numpyro/distributions/continuous.py#L1342) for the normal distribution. The key point is the use of `lax` and `jnp` to check the shapes of the arguments passed and make sure that these shapes are consistent for broadcasting. We follow the same pattern for our use case: all we need to do is include the `high` parameter.\n", "\n", "In the source implementation of `Normal`, both parameters `loc` and `scale` are given defaults so that one recovers a standard normal distribution if no arguments are specified. 
In the same spirit, we choose `float(\"inf\")` as a default for `high`, which is equivalent to no truncation.\n", "\n", "```python\n", "# ...\n", "    def __init__(self, loc=0.0, scale=1.0, high=float(\"inf\"), validate_args=None):\n", "        batch_shape = lax.broadcast_shapes(\n", "            jnp.shape(loc),\n", "            jnp.shape(scale),\n", "            jnp.shape(high),\n", "        )\n", "        self.loc, self.scale, self.high = promote_shapes(loc, scale, high)\n", "        super().__init__(batch_shape, validate_args=validate_args)\n", "# ...\n", "```\n", "\n", "**The** `log_prob` **method**\n", "\n", "For a truncated distribution, the log density is given by\n", "\n", "$$\n", "\\begin{align}\n", " \\log p_Z(z) =\n", " \\begin{cases}\n", " \\log p_Y(z) - \\log M & \\text{if $z$ is in $T$}\\\\\n", " -\\infty & \\text{if $z$ is outside $T$}\\\\\n", " \\end{cases}\n", "\\end{align}\n", "$$\n", "\n", "where, again, $p_Z$ is the density of the truncated distribution, $p_Y$ is the density before truncation, and $M = \\int_T p_Y(y)\\mathrm{d}y$. For the specific case of truncating the normal distribution to the interval `(-inf, high)`, the constant $M$ is the cumulative distribution function (cdf) of the untruncated normal evaluated at the truncation point `high`. We can easily implement this log-density method because `jax.scipy.stats` already has a `norm` module that we can use.\n", "\n", "```python\n", "# ...\n", "    def log_prob(self, value):\n", "        log_m = norm.logcdf(self.high, self.loc, self.scale)\n", "        log_p = norm.logpdf(value, self.loc, self.scale)\n", "        return jnp.where(value < self.high, log_p - log_m, -jnp.inf)\n", "# ...\n", "```\n", "\n", "**The** `sample` **method**\n", "\n", "To implement the sample method using inverse-transform sampling, we also need to implement the inverse cumulative distribution function. For this, we can use the `ndtri` function that lives inside `jax.scipy.special`. This function returns the inverse cdf of the standard normal distribution. 
We can do a bit of algebra to obtain the inverse cdf of the truncated, non-standard normal. First recall that if $X\\sim \\mathrm{Normal}(0, 1)$ and $Y = \\mu + \\sigma X$, then $Y\\sim \\mathrm{Normal}(\\mu, \\sigma)$. Then if $Z$ is the truncated $Y$, its cumulative distribution function is given, for $y < \\text{high}$, by:\n", "\n", "$$\n", "\\begin{align}\n", "F_Z(y) &= \\int_{-\\infty}^{y}p_Z(r)\\,\\mathrm{d}r\\\\\n", " &= \\frac{1}{M}\\int_{-\\infty}^{y}p_Y(s)\\,\\mathrm{d}s\\\\\n", " &= \\frac{1}{M}F_Y(y)\n", "\\end{align}\n", "$$\n", "\n", "where $M = F_Y(\\text{high})$. And so its inverse is\n", "\n", "$$\n", "\\begin{align}\n", "F_Z^{-1}(u) = \\left(\\frac{1}{M}F_Y\\right)^{-1}(u)\n", " = F_Y^{-1}(M u)\n", " = F_{\\mu + \\sigma X}^{-1}(Mu)\n", " = \\mu + \\sigma F_X^{-1}(Mu)\n", "\\end{align}\n", "$$\n", "\n", "The translation of the above math into code is\n", "\n", "```python\n", "# ...\n", "    def sample(self, key, sample_shape=()):\n", "        shape = sample_shape + self.batch_shape\n", "        # smallest positive float as lower bound, since ndtri(0) = -inf\n", "        minval = jnp.finfo(jnp.result_type(float)).tiny\n", "        u = random.uniform(key, shape, minval=minval)\n", "        return self.icdf(u)\n", "\n", "    def icdf(self, u):\n", "        # m = M = F_Y(high), the probability mass that survives truncation\n", "        m = norm.cdf(self.high, self.loc, self.scale)\n", "        return self.loc + self.scale * ndtri(m * u)\n", "```\n", "\n", "With everything in place, the final implementation is as below."
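] }, { "cell_type": "markdown", "id": "3c9e5b27", "metadata": {}, "source": [ "Before looking at the full class, a small worked example (with numbers chosen purely for illustration) can help confirm the inverse-cdf formula. Take $\\mu = 0$, $\\sigma = 1$ and $\\text{high} = 0$, so that $M = F_Y(0) = 1/2$. The median of the truncated distribution, obtained at $u = 1/2$, is then\n", "\n", "$$\n", "F_Z^{-1}\\left(\\tfrac{1}{2}\\right) = \\mu + \\sigma F_X^{-1}\\left(M \\cdot \\tfrac{1}{2}\\right) = F_X^{-1}(0.25) \\approx -0.674,\n", "$$\n", "\n", "which is what `ndtri(0.25)` returns. The result lies below the truncation point, as it must: the untruncated normal puts mass $0.25$ below $-0.674$ and another $0.25$ between $-0.674$ and $0$, so each side holds half of the mass kept after truncation."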
] }, { "cell_type": "code", "execution_count": 17, "id": "bef4c680", "metadata": {}, "outputs": [], "source": [ "class _RightExtendedReal(constraints.Constraint):\n", " \"\"\"\n", " Any number in the interval (-inf, inf].\n", " \"\"\"\n", "\n", " def __call__(self, x):\n", " return (x == x) & (x != float(\"-inf\"))\n", "\n", " def feasible_like(self, prototype):\n", " return jnp.zeros_like(prototype)\n", "\n", "\n", "right_extended_real = _RightExtendedReal()\n", "\n", "\n", "class RightTruncatedNormal(Distribution):\n", " \"\"\"\n", " A truncated Normal distribution.\n", " :param numpy.ndarray loc: location parameter of the untruncated normal\n", " :param numpy.ndarray scale: scale parameter of the untruncated normal\n", " :param numpy.ndarray high: point at which the truncation happens\n", " \"\"\"\n", "\n", " arg_constraints = {\n", " \"loc\": constraints.real,\n", " \"scale\": constraints.positive,\n", " \"high\": right_extended_real,\n", " }\n", " reparametrized_params = [\"loc\", \"scale\", \"high\"]\n", "\n", " def __init__(self, loc=0.0, scale=1.0, high=float(\"inf\"), validate_args=True):\n", " batch_shape = lax.broadcast_shapes(\n", " jnp.shape(loc),\n", " jnp.shape(scale),\n", " jnp.shape(high),\n", " )\n", " self.loc, self.scale, self.high = promote_shapes(loc, scale, high)\n", " super().__init__(batch_shape, validate_args=validate_args)\n", "\n", " def log_prob(self, value):\n", " log_m = norm.logcdf(self.high, self.loc, self.scale)\n", " log_p = norm.logpdf(value, self.loc, self.scale)\n", " return jnp.where(value < self.high, log_p - log_m, -jnp.inf)\n", "\n", " def sample(self, key, sample_shape=()):\n", " shape = sample_shape + self.batch_shape\n", " minval = jnp.finfo(jnp.result_type(float)).tiny\n", " u = random.uniform(key, shape, minval=minval)\n", " return self.icdf(u)\n", "\n", " def icdf(self, u):\n", " m = norm.cdf(self.high, self.loc, self.scale)\n", " return self.loc + self.scale * ndtri(m * u)\n", "\n", " 
@constraints.dependent_property\n", " def support(self):\n", " return constraints.less_than(self.high)" ] }, { "cell_type": "markdown", "id": "d4175f2e", "metadata": {}, "source": [ "Let's try it out!" ] }, { "cell_type": "code", "execution_count": 18, "id": "d3792cd1", "metadata": {}, "outputs": [], "source": [ "def truncated_normal_model(num_observations, x=None):\n", " loc = numpyro.sample(\"loc\", dist.Normal())\n", " scale = numpyro.sample(\"scale\", dist.LogNormal())\n", " high = numpyro.sample(\"high\", dist.Normal())\n", " with numpyro.plate(\"observations\", num_observations):\n", " numpyro.sample(\"x\", RightTruncatedNormal(loc, scale, high), obs=x)" ] }, { "cell_type": "code", "execution_count": 19, "id": "fcc1b782", "metadata": {}, "outputs": [], "source": [ "num_observations = 1000\n", "num_prior_samples = 100\n", "prior = Predictive(truncated_normal_model, num_samples=num_prior_samples)\n", "prior_samples = prior(PRIOR_RNG, num_observations)" ] }, { "cell_type": "markdown", "id": "e7ff52c5", "metadata": {}, "source": [ "As before, we run mcmc against some synthetic data.\n", "We select any random sample from the prior as the ground truth:" ] }, { "cell_type": "code", "execution_count": 20, "id": "9483511e", "metadata": {}, "outputs": [], "source": [ "true_idx = 0\n", "true_loc = prior_samples[\"loc\"][true_idx]\n", "true_scale = prior_samples[\"scale\"][true_idx]\n", "true_high = prior_samples[\"high\"][true_idx]\n", "true_x = prior_samples[\"x\"][true_idx]" ] }, { "cell_type": "code", "execution_count": 21, "id": "e471c6ff", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAEGCAYAAACevtWaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAV8ElEQVR4nO3dfZiVdZ3H8c9HxQd8QIQJCccd2SVKKScDM9pSVyzEVVIzcVdScYMKdmMvrUXUsHxoWmF9WFc2QMKkDUxpgSSKYQO3JRNsJ8VcNpbFBS4eBnwOU9Hv/jG3u+eigRmY8+M3c+b9uq655pzffZ/P+Ub65ed37nOPI0IAgMpyUO4CAADlR3MHgApEcweACkRzB4AKRHMHgAp0SO4CJKlnz55RU1OTuwwA6FCefPLJ7RFR1dyxdtHca2pqtGrVqtxlAMABtWHDBklSdXX1fr3e9nN7OtYumjsAdEYjR46UJC1btqzs2TR3AMjkxhtvTJZNcweATIYMGZIsm6tlACCTdevWad26dUmy2bkDQCajRo2SlGnmbrta0nck9ZIUkqZFxN22b5b0OUmNxakTI2JR8ZrrJV0j6S1JfxURPy575QDQwX3ta19Llt2anfsuSddGxC9tHy3pSdtLimN3RsTk0pNtnyxphKRTJL1bUr3t90TEW+UsHAA6ujPPPDNZdosz94jYHBG/LB6/IulZSX328pLhkuZExOsR8d+S1ko6vRzFAkAlWbNmjdasWZMke59+oGq7RtIHJf2iWBpn+ynbM213L9b6SNpQ8rKNauYvA9ujba+yvaqxsXH3wwBQ8caMGaMxY8YkyW71D1RtHyXpEUnjI+Jl21Ml3aKmOfwtkqZIGtXavIiYJmmaJA0cOJDfGAJANRMezfK+6+vOz/K+t99+e7LsVjV3213U1Ni/GxHzJCkitpYcny7ph8XTTZJKP0t7QrEGoAPI1WA7o8GDByfLbnEsY9uS7pf0bET8Xcl675LTLpK0uni8QNII24fZPklSP0lPlK9kAKgMq1ev1urVq1s+cT+0Zuf+UUkjJT1tu6FYmyjpctu1ahrLrJc0RpIi4hnbD0n6tZqutBnLlTIA8PvGjRsnKdN17hHxM0lu5tCivbzmNkm3taEuAKh4d9xxR7JsPqEKAJkMGjQoWTb3lgGATBoaGtTQ0JAkm507AGQyfvx4SdzPHQAqyl133ZUsm+YOAJnU1tYmy2bmDgCZrFy5UitXrkySzc4dADL58pe/LImZOwBUlHvvvTdZNs0dADIZMGBAsmxm7gCQyYoVK7RixYok2ezcASCTiRMnSmLmDgAV5Vvf+laybJo7AGTSv3//ZNnM3AEgk+XLl2v58uVJstm5A0AmkyZNksTMHQAqysyZM5Nl09wBIJO+ffsmy2bmDgCZ1NfXq76+Pkk2O3cAyOTWW2+VJA0ZMqTs2TR3AMjkwQcfTJZNcweATKqrq5NlM3MHgEwWL16sxYsXJ8lm5w4AmdTV1UmShg4dWvZsmjsAZDJnzpxk2TR3AMjk+OOPT5bNzB0AMlm4cKEWLlyYJJudOwBkMmXKFEnSBRdcUPZsmjsAZPLwww8ny6a5A0AmPXv2TJbNzB0AMpk3b57mzZuXJJudOwBkcs8990iSLr744rJn09wBIJP58+cny25xLGO72vZPbf/a9jO2v1SsH2d7ie3fFN+7F+u2fY/ttbafsn1asuoBoAPr1q2bunXrliS7NTP3XZKujYiTJZ0haaztkyVNkLQ0IvpJWlo8l6TzJPUrvkZLmlr2qgGgAsydO1dz585Nkt1ic4+IzRHxy+LxK5KeldRH0nBJDxSnPSDpU8Xj4ZK+E00el3Ss7d7lLhwAOrqpU6dq6tQ0+999mrnbrpH0QUm/kNQrIjYXh7ZI6lU87iNpQ8nLNhZrm0vWZHu0mnb2OvHEE/e1bgDo8BYtWpQsu9WXQto+StIjksZHxMulxyIiJMW+vHFETIuIgRExsKqqal9eCgAVoWvXrur
atWuS7FY1d9td1NTYvxsR71yUufWdcUvxfVuxvklS6R3oTyjWAAAlZs+erdmzZyfJbs3VMpZ0v6RnI+LvSg4tkHRl8fhKSfNL1j9bXDVzhqSXSsY3AIDCjBkzNGPGjCTZrZm5f1TSSElP224o1iZKqpP0kO1rJD0n6TPFsUWShklaK2mnpKvLWTAAVIolS5Yky26xuUfEzyR5D4fPaeb8kDS2jXUBQMXr0qVLsmzuLQMAmcyaNUuzZs1Kkk1zB4BMUjZ37i0DAJksW7YsWTY7dwCoQDR3AMhk+vTpmj59epJsmjsAZJLyxmHM3AEgk/r6+mTZ7NwBoALR3AEgk/vuu0/33XdfkmyaOwBksnDhQi1cuDBJNjN3AMjkRz/6UbJsdu4AUIFo7gCQyd1336277747STbNHQAyWbp0qZYuXZokm5k7AGSyYMGCZNns3AGgAtHcASCTyZMna/LkyUmyGcsAQCY///nPk2XT3AEgk0ceeSRZNmMZAKhANHcAyKSurk51dXVJshnLAEAmDQ0NybJp7gCQyZw5c5JlM5YBgArEzh1Ap1cz4dEs7/viv31P136iv2666aayZ9PcASCTXc9v0po1abIZywBAJj0vuE6zZ89Okk1zB4AKRHMHgExe/NfZ+upXv5okm5k7AGSy6+Xt2rDh8CTZ7NwBIJOe54/Xt7/97STZ7NyBdirX5XmoDOzcASCTF5bP0vXXX58ku8Xmbnum7W22V5es3Wx7k+2G4mtYybHrba+1vcb2J5NUDQAV4O3XXtGOHTuSZLdm5z5L0tBm1u+MiNria5Ek2T5Z0ghJpxSvuc/2weUqFgAqSY+hf6lp06YlyW6xuUfEY5Keb2XecElzIuL1iPhvSWslnd6G+gAA+6EtM/dxtp8qxjbdi7U+kjaUnLOxWPs9tkfbXmV7VWNjYxvKAICO6YV/uV/XXXddkuz9be5TJf2hpFpJmyVN2deAiJgWEQMjYmBVVdV+lgEAHdfbu97Qa6+9liR7vy6FjIit7zy2PV3SD4unmyRVl5x6QrEGANhNj098Qf9Qd36S7P3audvuXfL0IknvXEmzQNII24fZPklSP0lPtK1EAMC+as2lkN+T9HNJ/W1vtH2NpL+1/bTtpySdLemvJSkinpH0kKRfS1osaWxEvJWsegDowJ6vn6bx48cnyW5xLBMRlzezfP9ezr9N0m1tKQoA0DbcfgAAMjluyGjd1Z5m7gCA9o3mDgCZ7PjJVI0dOzZJNs0dADI56JBDdcQRRyTJZuYOAJl0/5NrNJmZOwCgtWjuAJDJjsV/r9GjRyfJprkDQCYHHXG0evTokSSbmTsAZNL9zKv0jW8wcwcAtBLNHQAy2f7oXbr66quTZNPcASCTQ47pqerq6pZP3J/sJKkAgBYd+7Er9PWvM3MHALQSzR0AMtm+cLKuuOKKJNk0dwDI5JDj+qh///5pspOkAgBadOxHL9dNNzFzBwC0Es0dADJpnP9NjRgxIkk2YxkAyOTQXn1VW/veJNk0dwDIpNsZl2rCBGbuAIBWorkDQCaNP7hdl1xySZJsxjIAkMmh736vPvKR9yXJprkDQCbdPnyxrruOmTsAoJVo7gCQybZHvq4LL7wwSTZjGQDI5PA/OFXnnHNKkmyaOwBkcszA4frSl5i5AwBaieYOAJlsfWiSzjvvvCTZjGUAIJOuf3S6LrhgQJLsFnfutmfa3mZ7dcnacbaX2P5N8b17sW7b99hea/sp26clqRoAKsDRp52vL37xi0myWzOWmSVp6G5rEyQtjYh+kpYWzyXpPEn9iq/RkqaWp0wAwL5ocSwTEY/Zrtltebiks4rHD0haJulvivXvRERIetz2sbZ7R8TmslUMHGA1Ex7NXQIq1NY5N2jIqjtVX19f9uz9nbn3KmnYWyT1Kh73kbSh5LyNxRrNHQB20/W9H9Nll3wgSXabf6AaEWE79vV1tkeraXSjE088sa1lAECHc3TtUH3uc+3rOvettntLUvF
9W7G+SVJ1yXknFGu/JyKmRcTAiBhYVVW1n2UAAJqzv819gaQri8dXSppfsv7Z4qqZMyS9xLwdAJq35Z8m6KyzzkqS3eJYxvb31PTD0562N0qaJKlO0kO2r5H0nKTPFKcvkjRM0lpJOyVdnaBmAKgIR71/iK669NQk2a25WubyPRw6p5lzQ9LYthYFAJ3BUe8foquual8zdwBAG8Vbu/Tmm28myaa5A0AmW+feqHPPPTdJNveWAYBMjjr1k/qLy2qTZNPcASCTo045W1dcwcwdACrK22/+Tjt37kySTXMHgEy2ff9mDRs2LEk2YxkAyOToDw7TF/4szZ3Rae4AkMmR7/u4LruMmTsAVJS3X/+tXnrppSTZNHcAyGTbI7do+PDhSbIZywBAJsd86EL91cgPJcmmuQNAJl37D9bFFzNzB4CK8tbOl7R9+/Yk2TR3AMik8Z+/oU9/+tNJshnLAEAmx5x+ka69clCSbJo7AGTS9Y8+rAsuYOYOABXlrVdf0JYtW5Jk09wBIJPGBd/UiBEjkmQzlgGATLqdcakmjDo9STY7dwDI5Ii+H9LQoUOTZNPcASCTXS83asOGDUmyae4AkMn2H07RyJEjk2QzcweATLoNHqEb/+LDSbLZuQNAJkfU1GrIkCFJsmnuAJDJmy9u0bp165Jk09wBIJMdi+7SqFGjkmQzcweATI794z/X18Z8JEk2O3cAyOTwE9+vM888M0k2zR0AMnlzx0atWbMmSTbNHQAy2fHjezVmzJgk2czcASCTYz9+pW7/4uAk2ezcASCTw094nwYPprkDQEV5o3G9Vq9enSS7TWMZ2+slvSLpLUm7ImKg7eMkzZVUI2m9pM9ExAttKxOdXc2ER3OXAJTd80v+UeP+6/tatmxZ2bPLsXM/OyJqI2Jg8XyCpKUR0U/S0uI5AGA33c8epTvuuCNJdoqxzHBJDxSPH5D0qQTvAQAd3mG936NBg9L8guy2NveQ9BPbT9oeXaz1iojNxeMtkno190Lbo22vsr2qsbGxjWUAQMfzxtZ1amhoSJLd1ub+xxFxmqTzJI21/fHSgxERavoL4PdExLSIGBgRA6uqqtpYBgB0PM8vnabx48cnyW7TD1QjYlPxfZvtH0g6XdJW270jYrPt3pK2laFOAKg4x50zWnd96WNJsvd75277SNtHv/NY0ickrZa0QNKVxWlXSprf1iIBoBId2quvamtrk2S3ZSzTS9LPbP9K0hOSHo2IxZLqJJ1r+zeShhTPAQC7eX3zf2rlypVJsvd7LBMR6ySd2sz6DknntKUoAOgMXvjpTH35uflJrnPn3jIAkMlx535e9/41t/wFgIpyaFWNBgwYkCSb5g4Amfxu47NasWJFkmyaOwBk8uJjD2jixIlJspm5A0AmPT45Tt+67qwk2ezcASCTLj1OUP/+/ZNk09wBIJPf/c/TWr58eZJsmjsAZPLiz76rSZMmJclm5g4AmfQYNl4zv3J2kmx27gCQSZdjj1ffvn2TZNPcASCT19Y3qL6+Pkk2zR0AMnlpxRzdeuutSbJp7gCQSc8/vVYPPvhgkmyaOwBkcsgxVaqurk6STXMHgExeW/ekFi9enCSb5g4Ambz0+PdVV5fm9xnR3AEgk6oL/0Zz5sxJks2HmLBPaiY8mrsEoGIcfFR3HX/88Umy2bkDQCY71/5CCxcuTJJNcweATF5+4geaMmVKkmyaOwBkUvWp6/Xwww8nyaa5A0AmB3ftpp49eybJprkDQCY716zQvHnzkmTT3AEgk5efXKB77rknSTbNHQAyedclN2n+/PlJsmnuAJDJQYcdqW7duqXJTpIKAGjRb599THPnzk2STXMHgExe+fdFmjp1apJsmjsAZPKuS2/WokWLkmTT3AEgk4O6HK6uXbumyU6SCgBo0avP/FSzZ89Okk1zB4BMXv3VjzVjxowk2TR3AMik12W3asmSJUmyk93P3fZQSXdLOljSjIhI8+tGOiHuqQ5UBh98iLp06ZI
kO8nO3fbBkv5B0nmSTpZ0ue2TU7wXAHRUrz5dr1mzZiXJTrVzP13S2ohYJ0m250gaLunX5X6jnLvY9XXnZ3tvAB3fq0/Xa9Zv/11XXXVV2bNTNfc+kjaUPN8o6cOlJ9geLWl08fRV22sS1bI3PSVt398X+5tlrGTv2lTnAUSd5dVR6pQ6Tq3trs7lGyTbuy+3ts4/2NOBbL9DNSKmSZqW6/0lyfaqiBiYs4bWoM7yos7y6yi1dqY6U10ts0lSdcnzE4o1AMABkKq5r5TUz/ZJtg+VNELSgkTvBQDYTZKxTETssj1O0o/VdCnkzIh4JsV7tVHWsdA+oM7yos7y6yi1dpo6HRHlKAQA0I7wCVUAqEA0dwCoQJ2+udv+S9v/YfsZ23+bu549sX2z7U22G4qvYblr2hvb19oO2z1z19Ic27fYfqr4s/yJ7Xfnrqk5tu8o/vl8yvYPbB+bu6bm2L60+Hfobdvt7lJD20Ntr7G91vaE3PXsie2ZtrfZXt3WrE7d3G2fraZPzp4aEadImpy5pJbcGRG1xVeaO/yXge1qSZ+Q9D+5a9mLOyLiAxFRK+mHkr6auZ49WSJpQER8QNJ/Sro+cz17slrSxZIey13I7jrY7VBmSRpajqBO3dwlfUFSXUS8LkkRsS1zPZXiTklfkdRuf1ofES+XPD1S7bTWiPhJROwqnj6ups+MtDsR8WxE5PiUeWv83+1QIuINSe/cDqXdiYjHJD1fjqzO3tzfI+ljtn9he7ntQbkLasG44j/PZ9runruY5tgeLmlTRPwqdy0tsX2b7Q2S/lztd+deapSkH+UuogNq7nYofTLVcsBku/3AgWK7XtLxzRy6QU3/+4+TdIakQZIest03Ml0f2kKtUyXdoqYd5i2SpqjpX/YDroU6J6ppJJPd3uqMiPkRcYOkG2xfL2mcpEkHtMBCS3UW59wgaZek7x7I2kq1pk60HxXf3CNiyJ6O2f6CpHlFM3/C9ttqumFP44Gqr9Teai1le7qa5sRZ7KlO2++XdJKkXxU3QjpB0i9tnx4RWw5giZJa/+eppoa5SJmae0t12r5K0p9KOifXxkPapz/P9qZT3g6ls49l/lnS2ZJk+z2SDlU7u2PcO2z3Lnl6kZp+gNWuRMTTEfGuiKiJiBo1/efvaTkae0ts9yt5OlzSf+SqZW+KX3rzFUkXRsTO3PV0UJ3ydiid+hOqxf/RMyXVSnpD0nUR8S9Zi9oD2w+qqc6QtF7SmIjYnLOmltheL2lgRLS7vzBtPyKpv6S3JT0n6fMR0e52c7bXSjpM0o5i6fGI+HzGkppl+yJJfy+pStKLkhoi4pNZiypRXDp8l/7/dii35a2oeba/J+ksNU0QtkqaFBH371dWZ27uAFCpOvtYBgAqEs0dACoQzR0AKhDNHQAqEM0dACoQzR0AKhDNHQAqEM0daIbtQcVN2g63fWRxr/IBuesCWosPMQF7YPtWSYdLOkLSxoj4RuaSgFajuQN7UNyeYqWk30kaHBFvZS4JaDXGMsCe9ZB0lKSj1bSDBzoMdu7AHtheoKbf2nOSpN4RMS5zSUCrVfz93IH9Yfuzkt6MiH8qfgfnCtt/0l7vGgrsjp07AFQgZu4AUIFo7gBQgWjuAFCBaO4AUIFo7gBQgWjuAFCBaO4AUIH+F+x+2DH4W0dYAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.hist(true_x.copy())\n", "plt.axvline(true_high, linestyle=\":\", color=\"k\")\n", "plt.xlabel(\"x\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "235cd8f4", "metadata": {}, "source": [ "Run MCMC and check the estimates:" ] }, { "cell_type": "code", "execution_count": 22, "id": "d3a979e5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1850.91it/s, 15 steps of size 8.88e-02. acc. prob=0.88]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7434.51it/s, 5 steps of size 1.56e-01. acc. prob=0.78]\n", "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7792.94it/s, 54 steps of size 5.41e-02. acc. prob=0.91]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7404.07it/s, 9 steps of size 1.77e-01. acc. 
prob=0.78]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " high 0.88 0.01 0.88 0.88 0.89 590.13 1.01\n", " loc -0.58 0.07 -0.58 -0.70 -0.46 671.04 1.01\n", " scale 1.40 0.05 1.40 1.32 1.48 678.30 1.01\n", "\n", "Number of divergences: 6310\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "mcmc = MCMC(NUTS(truncated_normal_model), **MCMC_KWARGS)\n", "mcmc.run(MCMC_RNG, num_observations, true_x)\n", "mcmc.print_summary()" ] }, { "cell_type": "markdown", "id": "d57f0535", "metadata": {}, "source": [ "Compare estimates against the ground truth:" ] }, { "cell_type": "code", "execution_count": 23, "id": "dedf4f61", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True high : 0.88\n", "True loc : -0.56\n", "True scale: 1.45\n" ] } ], "source": [ "print(f\"True high : {true_high:3.2f}\")\n", "print(f\"True loc : {true_loc:3.2f}\")\n", "print(f\"True scale: {true_scale:3.2f}\")" ] }, { "cell_type": "markdown", "id": "f4a1e5e5", "metadata": {}, "source": [ "Note that, even though we can recover good estimates for the true values,\n", "we had a very high number of divergences. 
These divergences happen because\n", "the data can be outside of the support that we are allowing with our priors.\n", "To fix this, we can change the prior on `high` so that it depends on the observations:" ] }, { "cell_type": "code", "execution_count": 24, "id": "894a68e7", "metadata": {}, "outputs": [], "source": [ "def truncated_normal_model_2(num_observations, x=None):\n", " loc = numpyro.sample(\"loc\", dist.Normal())\n", " scale = numpyro.sample(\"scale\", dist.LogNormal())\n", " if x is None:\n", " high = numpyro.sample(\"high\", dist.Normal())\n", " else:\n", " # high is greater or equal to the max value in x:\n", " delta = numpyro.sample(\"delta\", dist.HalfNormal())\n", " high = numpyro.deterministic(\"high\", delta + x.max())\n", "\n", " with numpyro.plate(\"observations\", num_observations):\n", " numpyro.sample(\"x\", RightTruncatedNormal(loc, scale, high), obs=x)" ] }, { "cell_type": "code", "execution_count": 25, "id": "9161babf", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:03<00:00, 1089.76it/s, 15 steps of size 4.85e-01. acc. prob=0.93]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 8802.95it/s, 7 steps of size 5.19e-01. acc. prob=0.92]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 8975.35it/s, 3 steps of size 5.72e-01. acc. prob=0.89]\n", "sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 8471.94it/s, 15 steps of size 3.76e-01. acc. 
prob=0.96]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " delta 0.01 0.01 0.00 0.00 0.01 6104.22 1.00\n", " high 0.88 0.01 0.88 0.88 0.89 6104.22 1.00\n", " loc -0.58 0.08 -0.58 -0.71 -0.46 3319.65 1.00\n", " scale 1.40 0.06 1.40 1.31 1.49 3377.38 1.00\n", "\n", "Number of divergences: 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "mcmc = MCMC(NUTS(truncated_normal_model_2), **MCMC_KWARGS)\n", "mcmc.run(MCMC_RNG, num_observations, true_x)\n", "mcmc.print_summary(exclude_deterministic=False)" ] }, { "cell_type": "markdown", "id": "f8ceb53d", "metadata": {}, "source": [ "And the divergences are gone." ] }, { "cell_type": "markdown", "id": "d5454e89", "metadata": {}, "source": [ "In practice, we usually want to understand what the data\n", "would look like without the truncation. To do that in NumPyro,\n", "there is no need to write a separate model; we can simply\n", "rely on the `condition` handler to push the truncation point to infinity:" ] }, { "cell_type": "code", "execution_count": 26, "id": "3a3c6831", "metadata": {}, "outputs": [], "source": [ "model_without_truncation = numpyro.handlers.condition(\n", "    truncated_normal_model,\n", "    {\"high\": float(\"inf\")},\n", ")\n", "estimates = mcmc.get_samples().copy()\n", "estimates.pop(\"high\")  # Drop to make sure these are not used\n", "pred = Predictive(\n", "    model_without_truncation,\n", "    posterior_samples=estimates,\n", ")\n", "pred_samples = pred(PRED_RNG, num_observations=1000)" ] }, { "cell_type": "code", "execution_count": 27, "id": "6139d303", "metadata": {}, "outputs": [], "source": [ "# thin the samples for a faster histogram\n", "samples_thinned = pred_samples[\"x\"].ravel()[::1000]" ] }, { "cell_type": "code", "execution_count": 28, "id": "28877060", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA3AAAAE/CAYAAAAHeyFHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABVXUlEQVR4nO3de5xO5f7/8denMZKUhI4U7Y1yHIxD2UKIInLoS6WoXXYHyW+3RWVTOmx7d9hqp3OyK2VKOVSqjagkOe0pIoVG6OSUHGO4fn+sNbNvY2bu28zc91o37+fjcT+611rXWtdn3ab7uj73uta1zDmHiIiIiIiIhN9RQQcgIiIiIiIisVECJyIiIiIikiSUwImIiIiIiCQJJXAiIiIiIiJJQgmciIiIiIhIklACJyIiIiIikiSUwImEmJm1NrN1QcdxqMzsKTP7a9BxiIhI4czsSjP7T9BxlBQzm21m1wUdx6Ews3Fmdl+MZauZmTOzUvGOS8JLCZwEzv8i+n2edXeb2csx7h/zF1+ihTm2gpRE4+ecu8E5d29JxSQiEnZmdoWZLTSz7Wb2g5m9a2Z/CDquaJxz451zFyayTjPLMrN2iazzSJSsPwJLdErg5LCnX6kSy8xSirGv/q1EJOmY2Z+B0cADwMnAGcATQNcAw4pK37lFo89NgqYETkIv5xckM7vNzH72f9m8xt/WH7gSuN3/1fMtf32WmQ0xsy+AHWZWKu+VvsirY4XV4W8/xsweNrM1ZrbVzOaY2TH+ttfN7Ed//UdmVidKbKeZ2RtmtsHMvjWzgXnqGWdmW8xsGdAkymfjzGygma02s41m9qCZHeVvO8rMhvkx/2xmL5pZeX9bGTN72cw2mdkvZrbAzE42s/uBlsDjfsyP++XPNrPpZrbZzFaY2f/l+RyfNLNpZrYDaJP3yqOZXW9mK/39p5rZaXnO4WYz+wb4Jta/CxGRMPC/V0cCNzvn3nTO7XDO7XXOveWcG+yXOdrMRpvZ9/5rtJkd7W/LaX9uj2h/LjWzi83sa/97886I+u42s4lmlmFm28xssZk1iNg+1MxW+duWmVm3iG39zOwTM/unmW0C7vbXzfG3m7/tZzP71cyWmFndnPP025ENfrsyLKK96ee3iw/57de3ZnZRAZ/XS3gJ7lt+O3O7v76LmX3pt0mzzeycQj7z8/x2a6v/3/PyFPmdmc33z2GKmZ3o75dv2xdxfs/7n/96M7vP/B8k8/nc7vX3rxsRU2Uz22VmJ/nLnc0s0y8318zqR5Rt6P+7bTOzDKBMIeea4n+uG81sNdApz/ZrzGy5f6zVZvYnf/2xwLvAaf7nvN28/kdTM/vUj+sHM3vczEoXVL+ElHNOL70CfQEO+H2edXcDL/vvWwPZeA1kKnAxsBOo4G8fB9yXZ/8sIBOoChyTXz2R+8VQxxhgNnA6kAKcBxztb7sWOA44Gu8X2Mz86vCXjwIWAcOB0sBZwGqgg799FPAxcKIf+1JgXZTPbpZf/gzga+C6iLhW+nWUA94EXvK3/Ql4Cyjrn09j4Hh/2+ycY/jLxwJrgWuAUkBDYCNQO+IctwIt/PMrk+ezvcAv38j/jP4FfJTnHKb753BM0H+Peumll16H8gI6+u1HqULKjATmAScBlYG5wL3+tpz2Z7jf/lwPbABe8duWOsAuoLpf/m5gL9DTL/8X4Fsg1d9+GXCa/33cC9gBnOpv6+fXdYv/fX6Mv26Ov72D30adABhwTsS+LwJT/Jiq+e3NHyOOu9ePPQW4EfgesAI+jyygXcRyTT/O9v453Y7XfpXOZ98TgS3AVf45XO4vV/S3zwbWA3Xx2q83+F9/orC2bxLwtL/PScB84E+FfG5jgfsj4roZeM9/3xD4GWjm19PXP+ej8dr+NcD/88+1p//Z3VfAZ3UD8BVen+BEvDbf4f+94SV0v/P/vVrh9V0aRfxtrctzvMZAc/88qgHLgUFB/3+k1yF+7wQdgF56EVsCt4uIxtH/Ymzuvx+
X94vP/6K8trB6ODiBy7cOvEZwF9AghnM5wa+nfH6x+V/m3+XZ5w7gBf/9aqBjxLb+eb988zmnyPI3ATP99zOBmyK21fIbiVJ4yd1coH4+x5zNgQlcL+DjPGWeBkZEnOOLebZHfrbPA/+I2FbOj6NaxDlcEPTfoV566aVXUV54Iy1+jFJmFXBxxHIHIMt/n9P+pPjLx/nfi80iyi8CLvXf3w3Mi9h2FPAD0LKAujOBrv77fvm0Qf34XwJ3AV5i1hw4KqJMCrAH/4c7f92fgNkRx1gZsa2sfw6nFBBTFgcmcH8FXstzTuuB1vnsexUwP8+6T4F+/vvZwKiIbbX92FMKavvwhr3+RsSPiHiJ4axCPrd2wKqI5U+Aq/33T+In6BHbV+AlWOeTJ7n1YyoogfsAuCFi+UIiErh8yk8Gbo342yqwD+GXGQRMCvr/I70O7aUxvBIG+/B+hYqUitfJz7HJOZcdsbwTLxEozNpDjKOgOirhXVValXcHf3jF/Xi/eFYG9vubKuFdlcrrTLzhDL9ErEvBu+oG3q+mkXGviSHuvOVzhieelmf/NXjJ28nAS3i/5k0wsxOAl4G7nHORn3lkzM3yxFzKP0Z+MeR1GrA4Z8E5t90fgnI6XiMebX8RkTDbBFQys1J52pBI+X0fnxaxvMk5t89/v8v/708R23dxYJuX+53pnNtv3kQVpwGY2dXAn/GursD/2rGD9s3LOfeBeUPnxwBnmtmbeFf4jsFrl/Oew+kRyz9GHGenmeXUHYsDPh//nNbmOX6+ZQuIJW+7mIr3GeTb9uG1c6nAD37c4CWRkcfJ+7nNAsqaWTO8f6s0vKt4+Mfra2a3RJQv7cfugPXOz54iYixIof0Cf6jqCLyrmEfhJc9LCjqYmdUEHgHS/bKl8H4gkCSie+AkDL7jfw1NjurElryA92UYy/qdeF9WOU6J8fgbgd14QxTyugLvJvV2QHn+dx45LUDeGNYC3zrnToh4Heecu9jf/gNe45LjjBjiy1v+e//993iNSOS2bOAn592fcY9zrjbecNDOwNWFxPxhnpjLOedujChT0L/BQXH44/Ir4v26Gsv+IiJh9ine1ZtLCymT3/fx9wWUjUXu975/H1oV4HszOxN4FhiAN6TwBLyh+Baxb6Hft865x5xzjfGuXNUEBuO1g3vzOYf1Bx8hJnljyNtOGN455nf8vJ9lfrHkbRf3AhsLafvW4v0bVopo5453ztUpKGY/4X4N70rd5cDbzrlt/ua1eMMrI9vNss65V/Ha+dMtIlOk8La+wH6BefdRvgE8BJzs/3tPo+A+CHhXB78Cajjnjgfu5MC/D0kCSuAkDDKAYWZWxbyJN9oBlwATY9z/J7z7vKLJBK7wbwjuiDeUISrn3H68se6P+DcAp5jZuf4X53F4X/qb8JLDB6LENh/YZt4EK8f4x6prZjmTlbwG3GFmFcysCt54+2gG++WrArfifZ4ArwL/z8yqm1k5P7YM51y2mbUxs3r+FcRf8Rq3nKuHeWN+G6hpZleZWar/alLYDeZ5vApcY2Zp/mf2APCZcy4rxv1FRELLObcV7/61MeZNPlLW/568yMz+4Rd7Fa+dq2xmlfzyMT0qpwCNzay7ebMhDsJrh+bh3b/l8O6hw7zJuOoWdJC8/O/2ZmaWindP2m5gf0Sycr+ZHecnin8uxjnkbWdeAzqZWVu/7tv8c5qbz77T8NqkK8yboKwXXrL5dkSZPmZW28zK4t1/ONE5t6+gts859wPwH+BhMzve74v8zsyi9RNewbvN4Er/fY5ngRv8z9LM7Fgz62Rmx+El/NnAQP/vpDvQtJA6XvPLVjGzCsDQiG2l8e6r2wBk+1fjIh8J8RNQ0fwJzHzH+ee+3czOxrtfUZKMEjgJg5F4X9Jz8G5E/gdwpXNuaYz7Pw/U9mdUmlxIuVvxEsNf8L5sCyub11/whiQsADYDf8f7/+dFvCuF64FleA1ogbH5jWBnvKEW3+L9qvkc3tU7gHv
8432L15i8RHRT8IY/ZALv+HWCl3S+BHzkH283/0sIT8FLkH/Fu4H5w4i6HgV6mjeT2GP+L4oXAr3xfvn80T//o2OIDefcDLz7G97A+yXxd/6xREQOC865h/ESmmF4nem1eFfBJvtF7gMWAl/gtSWL/XVFNQUvcdiCd09Yd//q0jLgYbwk4SegHt69WbE6Hi/52ILXFm0CHvS33YKX1K3Ga69fwWtniuJveAntL2b2F+fcCqAP3iRXG/Ha6kucc3vy7uic24TXjt7mx3c70Nk5tzGi2Et492L/iHcLRM5sz4W1fVfjJUTL/POfCJxa2Ek45z7D+0xOw5vxMWf9QrwJXR73j7US7z46/HPq7i9vxvt3fLOQap4F3gc+x/u7yS3rt88D8ZK8LXijgqZGbP8K78eD1f5nfRpef+YKYJt/7JwffSWJ2IFDcEUkmZiZwxsGsTLoWEREJP7M7G68Cbn6BB2LiARDV+BERERERESShBI4ERERERGRJKEhlCIiIiIiIklCV+BERERERESShBI4ERERERGRJFEqlkL+M7MeBVKA55xzo/JsvwG4GdgHbAf6+1PZYmZ3AH/0tw10zr1fWF2VKlVy1apVO8TTEJF4WrPGe6b6mWfmfXaqSPEsWrRoo3OuctBxJAu1kSIi4VZSfabC2seoCZz/sMMxQHtgHbDAzKbmJGi+V5xzT/nluwCPAB3NrDbe857q4D0jY4aZ1fSfhZWvatWqsXDhwhhPTUQS4Y477gDgb3/7W8CRyOHGzNYEHUMyURspIhJuJdVnKqx9jOUKXFNgpXNutX+wCUBXvAcdAuCc+zWi/LFAzswoXYEJzrnfgG/NbKV/vE8P6QxEJFBK3ERERESiS0SfKZYE7nRgbcTyOqBZ3kJmdjPwZ7yn2F8Qse+8PPuens++/YH+AGeccUYscYuIiIiIiBxxSmwSE+fcGOfc74AhwLBD3PcZ51y6cy69cmXdCiESNtdccw3XXHNN0GGIiIiIhFoi+kyxXIFbD1SNWK7iryvIBODJIu6br71797Ju3Tp27959qLuK5CpTpgxVqlQhNTU16FCSTtWqVaMXEpFAqI2UZKa2WQ43iegzxZLALQBqmFl1vOSrN3BFZAEzq+Gc+8Zf7ATkvJ8KvGJmj+BNYlIDmH+oQa5bt47jjjuOatWqYWaHursIzjk2bdrEunXrqF69etDhJJ2RI0cGHYJIKBV1lmYzqwYsB1b4Rec5524oSgxqIyVZqW2Ww1Ei+kxREzjnXLaZDQDex2ugxjrnvjSzkcBC59xUYICZtQP2AluAvv6+X5rZa3gTnmQDNxc2A2VBdu/erYZJisXMqFixIhs2bAg6FBE5TBRnlmZ/2yrnXFpx41AbKclKbbNI0cT0HDjn3DRgWp51wyPe31rIvvcD9xc1wBxqmKS49DdUdH369AHg5ZdfDjgSkVApzizNJUrfb5Ks9Lcrh5tE9JlKbBKTw9mmTZtIS0sjLS2NU045hdNPPz13ec+ePQmJ4ZdffuGJJ57IXf7+++/p2bNnQurO66mnnuLFF18stExmZibTpk0rtIwkj1q1alGrVq2gwxAJm/xmac5vpuWbzWwV8A9gYMSm6mb2XzP70MxaFlSJmfU3s4VmtjCMVyrURh5IbaTIkS0RfSZzLi4/BhZZenq6y/uQ0uXLl3POOecEFNGB7r77bsqVK8df/vKX3HXZ2dmUKhXTxcwiy8rKonPnzixdujSu9ZSUcePGsXDhQh5//PGgQzlAmP6WRATMbJFzLj3oOIrCzHoCHZ1z1/nLVwHNnHMDCih/BdDBOdfXzI4GyjnnNplZY2AyUCfPFbuDqI3Mn9rI5Bamv2GRsCisfdQVuCLq168fN9xwA82aNeP222/n7rvv5qGHHsrdXrduXbKyssjKyuKcc87h+uuvp06dOlx44YXs2rULgJUrV9KuXTsaNGhAo0aNWLVqFdu3b6dt27Y0atSIevXqMWXKFACGDh3KqlWrSEtLY/D
gwWRlZVG3bl3Au//hmmuuoV69ejRs2JBZs2YBXgPRvXt3OnbsSI0aNbj99tvzPZdq1apx++23U69ePZo2bcrKlSsBr0G84IILqF+/Pm3btuW7774DOOBcW7duzZAhQ2jatCk1a9bk448/Zs+ePQwfPpyMjAzS0tLIyMiIw7+AiEjgijJL86UAzrnfnHOb/PeLgFVAzfiEmXhqI9VGikj8KIErhnXr1jF37lweeeSRQst988033HzzzXz55ZeccMIJvPHGGwBceeWV3HzzzXz++efMnTuXU089lTJlyjBp0iQWL17MrFmzuO2223DOMWrUKH73u9+RmZnJgw8+eMDxx4wZg5mxZMkSXn31Vfr27Zs7nXRmZiYZGRksWbKEjIwM1q5de1B8AOXLl2fJkiUMGDCAQYMGAXDLLbfQt29fvvjiC6688koGDhyY777Z2dnMnz+f0aNHc88991C6dGlGjhxJr169yMzMpFevXofysUoI9e7dm969ewcdhkjY5M7SbGal8WZpnhpZwMxqRCzmztJsZpX9SVAws7PwZmlenZCoE0RtpEdtpMiRJRF9pviOaYiT1q1b069fP/r168fevXtp37491113HX369GHnzp1cfPHF3HjjjfTq1YutW7fStWtXBg4cSPfu3dm4cSM9e/bktttu45JLLuHHH3/klFNOKVIcl112GSkpKVHLVa9enbS0NAAaN25MVlYW27ZtY/369XTr1g3wnoMC3vN87rzzTj766COOOuoo1q9fz08//VTo8efMmcMtt9wCwNlnn82ZZ57J119/DUDbtm0pX748ALVr12bNmjX5Pp/i8ssvz/3v//t//w+ATz/9lDfffBOAq666qsBfJ7t3737AuR1pqg19p0SOkzWqU4kcJx5y/n5F5H+KM0szcD4w0sz2AvuBG5xzm0siLrWRB1IbKRJns/4WdASeNncEHQGQmD5TUiZwYXHsscfmvi9VqhT79+/PXY58oOrRRx+d+z4lJSV3eEh+xo8fz4YNG1i0aBGpqalUq1atWA9nzVt3dnZ2vuUiZ4E61Bmhcuoo7PiS3IYOHRp0CCKhVNRZmp1zbwBvxDe6YKmNPLAOtZFSosKSNMlBEtFnSsoEbvbs2bnvU1NTD1guW7bsAcvly5c/YLlSpUoHLBf1l8W8qlWrxttvvw3A4sWL+fbbbwstf9xxx1GlShUmT57MpZdeym+//ca+ffvYunUrJ510EqmpqcyaNYs1a9bklt+2bVu+x2rZsiXjx4/nggsu4Ouvv+a7776jVq1aLF68OOb4MzIyGDp0KBkZGZx77rkAnHfeeUyYMIGrrrqK8ePH07JlgZOk5Xt+BcUrIiLxozbyQGojRQ7d6JlfF2v/QW0Pm1t6Q0n3wJWQHj16sHnzZurUqcPjjz9OzZrR/3BfeuklHnvsMerXr895553Hjz/+yJVXXsnChQupV68eL774ImeffTYAFStWpEWLFtStW5fBgwcfcJybbrqJ/fv3U69ePXr16sW4ceMO+FUxFlu2bKF+/fo8+uij/POf/wTgX//6Fy+88AL169fnpZde4tFHH435eG3atGHZsmW6Qfsw0aNHD3r06BF0GCKSpNRGHkhtpMjhKxF9Jj1GQKhWrRoLFy6kUqVKQYcSd/H4WzoS7oHLmVEtcmpwkZKQzI8RCILayMQ7ktrIoOhvuAjiPIQyKa/AheQeuJLqMxXWPiblEEoRSSwlbiIiIiLRJaLPpARONCuWiIhIAdRGikjY6B44EYmqS5cudOnSJegwREREREItEX0mXYETkajatm0bdAgiIiIioZeIPpMSOBGJ6tZb832UlYiIiIhESESfSUMoRUREREREkoQSuBhkZWVRt27dA9bdfffdudOEFiQzM5Np06bFM7SoHnjggUPeZ9y4cQwYMCAO0XiKEhPAddddx7Jly0o4GonFRRddxEUXXRR0GCISQps2bSItLY20tDROOeUUTj/99NzlPXv2HFD2qaee4sU
XXwSgX79+TJw4MeZ6vvrqK9LS0mjYsCGrVq0qVsxZWVm88sorxTpGcY0ePZqdO3ce0j6zZ8+mc+fOcYqoaDEBDB8+nBkzZsQhIpHkk4g+U1IOoSyp527liNfztzIzM1m4cCEXX3zxQduys7MpVSr+H/8DDzzAnXfeGfd6DkVRYtq3bx/PPffcIe+TkpJySPtI/i655JKgQxCRGCW6jaxYsSKZmZmA9+NmuXLlCpxG+4YbbihyHJMnT6Znz54MGzasyMfIkZPAXXHFFQdtS1T7PHr0aPr06UPZsmXjXlesihLTvn37GDly5CHVo/b58JeUz5ErIYnoM+kKXAlo3bo1Q4YMoWnTptSsWZOPP/6YPXv2MHz4cDIyMkhLSyMjI4O7776bq666ihYtWnDVVVcddKWrc+fOzJ49G4By5cpx11130aBBA5o3b85PP/0EwE8//US3bt1o0KABDRo0YO7cuQBceumlNG7cmDp16vDMM88AMHToUHbt2kVaWhpXXnklAC+//DJNmzYlLS2NP/3pT+zbtw+AF154gZo1a9K0aVM++eSTfM8zJ/5zzz2XGjVq8OyzzwLgnGPw4MHUrVuXevXqkZGRAcAPP/zA+eefT1paGnXr1uXjjz8+pJjKlSvHbbfdRoMGDfj0009p3bo1OQ+wffXVV6lXrx5169ZlyJAhuTHm3UdKxk033cRNN90UdBgikiSeffZZmjRpQoMGDejRo0fuVZ1YR680b96c+vXr061bN7Zs2cK0adMYPXo0Tz75JG3atDlon3LlyuW+nzhxIv369QO8q3wDBw7kvPPO46yzzsq94jd06FA+/vhj0tLS+Oc//8m4cePo0qULF1xwAW3btj3oSteAAQMYN24c4D3Ye8SIETRq1Ih69erx1VdfAbB9+3auueYa6tWrR/369XnjjTcAuPHGG0lPT6dOnTqMGDECgMcee4zvv/+eNm3a5J7Pf/7zH84991waNWrEZZddxvbt2wF47733OPvss2nUqBFvvvlmvp/ZuHHj6Nq1K61bt6ZGjRrcc889udseeeQR6tatS926dRk9ejQAO3bsoFOnTjRo0IC6deuSkZFxSDFVq1aNIUOG0KhRI15//fUDrqbOnDmThg0bUq9ePa699lp+++23fPcROVwlos+UlFfgwig7O5v58+czbdo07rnnHmbMmMHIkSNZuHAhjz/+OOA1XMuWLWPOnDkcc8wxuY1Bfnbs2EHz5s25//77uf3223n22WcZNmwYAwcOpFWrVkyaNIl9+/blfpmOHTuWE088kV27dtGkSRN69OjBqFGjePzxx3N/GV2+fDkZGRl88sknpKamctNNNzF+/Hjat2/PiBEjWLRoEeXLl6dNmzY0bNgw37i++OIL5s2bx44dO2jYsCGdOnXi008/JTMzk88//5yNGzfSpEkTzj//fF555RU6dOjAXXfdxb59+9i5cyctW7aMKaarr76aHTt20KxZMx5++OEDYvj+++8ZMmQIixYtokKFClx44YVMnjyZSy+9tMB9kkFJ/Goer6vJIiKHonv37lx//fUADBs2jOeff55bbrklpn2vvvpq/vWvf9GqVSuGDx/OPffcw+jRo7nhhhsKvbpXkB9++IE5c+bw1Vdf0aVLF3r27MmoUaN46KGHePvttwEvAVq8eDFffPEFJ554Yu6PqQWpVKkSixcv5oknnuChhx7iueee495776V8+fIsWbIEgC1btgBw//33c+KJJ7Jv3z7atm3LF198wcCBA3nkkUeYNWsWlSpVYuPGjdx3333MmDGDY489lr///e888sgj3H777Vx//fV88MEH/P73v6dXr14FxjR//nyWLl1K2bJladKkCZ06dcLMeOGFF/jss89wztGsWTNatWrF6tWrOe2003jnHa/d2bp1K+XLl48ppuHDhwPeldfFixcDXpIJsHv3bvr168fMmTOpWbMmV199NU8++SSDBg06aB8RKTolcDEws6jru3fvDkDjxo0
Lfehnly5dOOaYY6LWWbp06dxf/xo3bsz06dMB+OCDD3LvH0hJSaF8+fKA92vepEmTAFi7di3ffPMNFStWPOCYM2fOZNGiRTRp0gSAXbt2cdJJJ/HZZ5/RunVrKleuDECvXr34+uv8L3137dqVY445hmOOOYY2bdowf/585syZw+WXX05KSgonn3wyrVq1YsGCBTRp0oRrr72WvXv3cumll5KWlnbQ8QqKKef8evTocdA+CxYsOCDeK6+8ko8++ohLL720wH2keNq1awegexxEJCZLly5l2LBh/PLLL2zfvp0OHTrEtN/WrVv55ZdfaNWqFQB9+/blsssuK1Ysl156KUcddRS1a9fOHc2Sn/bt23PiiSfGdMzINj/nqtiMGTOYMGFCbpkKFSoA8Nprr/HMM8+QnZ3NDz/8wLJly6hfv/4Bx5s3bx7Lli2jRYsWAOzZs4dzzz2Xr776iurVq1OjRg0A+vTpkzvKJr/4c9r97t27M2fOHMyMbt26ceyxx+au//jjj+nYsSO33XYbQ4YMoXPnzrRs2fKg4xUUU478kskVK1ZQvXp1atb0hr/17duXMWPG5CZwhSWgIoeLRPSZlMDFoGLFirm/pOXYvHkz1atXz10++uijAS/pyM7OLvBYOV+iAKVKlWL//v25y7t37859n5qampsgRjvm7NmzmTFjBp9++illy5aldevWBxwrh3OOvn378re//e2A9ZMnTy7w2HnlTWYLSm4Bzj//fD766CPeeecd+vXrx5///GeuvvrqmGICKFOmzCGPkS/KPhKdGl0RORT9+vVj8uTJNGjQgHHjxkW9olVckW1R3vYvp30Gr80pSKztc+Qxo7XP3377LQ899BALFiygQoUK9OvXr8D2uX379rz66qsHrM8ZrRKLQ2mfa9asyeLFi5k2bRrDhg2jbdu2uVfWosWUI/LzilVR9hFJNonoM+keuBiUK1eOU089lQ8++ADwkrf33nuPP/zhD4Xud9xxx7Ft27YCt1erVo3MzEz279/P2rVrmT9/ftRY2rZty5NPPgl4NwFv3bqVrVu3UqFCBcqWLctXX33FvHnzcsunpqayd+/e3H0nTpzIzz//nHsea9asoVmzZnz44Yds2rSJvXv3Fjo2fcqUKezevZtNmzYxe/ZsmjRpQsuWLcnIyGDfvn1s2LCBjz76iKZNm7JmzRpOPvlkrr/+eq677rrcYROxxFSYpk2b8uGHH7Jx40b27dvHq6++mvtrrcTH9ddfnzscSkQkmm3btnHqqaeyd+9exo8fH/N+5cuXp0KFCnz88ccAvPTSSzF9v5988sksX76c/fv3545GKUy09vnMM89k2bJl/Pbbb/zyyy/MnDkz6jHbt2/PmDFjcpe3bNnCr7/+yrHHHkv58uX56aefePfdd/ONoXnz5nzyySesXLkS8G6j+Prrrzn77LPJysrKnXWzoGQKYPr06WzevJldu3YxefJkWrRoQcuWLZk8eTI7d+5kx44dTJo0iZYtW/L9999TtmxZ+vTpw+DBg3Pb51hiKkytWrXIysrK3SfWfz8peaNnfl2slxRdIvpMugIXoxdffJGbb76ZP//5zwCMGDGC3/3ud4Xu06ZNG0aNGkVaWhp33HHHQdtbtGhB9erVqV27Nueccw6NGjWKGsejjz5K//79ef7550lJSeHJJ5+kY8eOPPXUU5xzzjnUqlWL5s2b55bv378/9evXp1GjRowfP5777ruPCy+8kP3795OamsqYMWNo3rw5d999N+eeey4nnHBCvkMdc9SvX582bdqwceNG/vrXv3LaaafRrVs3Pv30Uxo0aICZ8Y9//INTTjmFf//73zz44IOkpqZSrly53KGfscR05plnFhjDqaeeyqhRo2jTpg3OOTp16kTXrl2jfnYiIpIY9957L82aNaNy5co0a9as0GQpr3//+9/ccMMN7Ny5k7POOosXXngh6j6jRo2ic+fOVK5cmfT09Nz7wwtSv359UlJSaNCgAf3
69csd7pijatWq/N///R9169alevXqBd4XHmnYsGHcfPPN1K1bl5SUFEaMGEH37t1p2LAhZ599NlWrVs0djgheW9ixY0dOO+00Zs2axbhx47j88stzJ/247777qFmzJs888wydOnWibNmytGzZssDPsmnTpvTo0YN169bRp08f0tPTAe9qaNOmTQHvcTwNGzbk/fffZ/DgwRx11FGkpqbm/jAca0wFKVOmDC+88AKXXXYZ2dnZNGnSpFgzj4pI/qyw4QRBSE9PdzkzDeZYvnw555xzTkARSY5o00Mng3j8LZX0lN3FEa9JTFq3bg0Q92FQcuQxs0XOufSg40gWaiMlP+PGjTtg0rRko7/hIph18K0nkZL9KlqRHiPQ5uCLJUEoqT5TYe2jrsCJSFQ5U3KLiIiISMES0WdSAicxu/vuu4MOQQKiBE5EJLz69eun72mRkEjE/4uaxEREotq7d2/uxDMiIiIikr9E9JmS5gqcc67QKXFFognb/Z7JpH379oDugRMREREpTCL6TEmRwJUpU4ZNmzZRsWJFJXFSJM45Nm3aRJkyZYIOJSldd911QYcgIiIiEnqJ6DMlRQJXpUoV1q1bx4YNG4IORZJYmTJlqFKlStBhJKU+ffoEHYKIiIhI6CWiz5QUCVxqairVq1cPOgyRI9bOnTsBKFu2bMCRiEgYrVu3jptvvplly5axf/9+OnfuzIMPPkjp0qVDO8V9uXLloj4v7oEHHuDOO+9MUEQHmzx5MjVr1qR27dqHtF8s5xbLo4GKWr/IkSwRfaakSOBEJFgXX3wxoHvgRJJClOdDHbIoz1ZyztG9e3duvPFGpkyZwr59++jfvz933XUXDz74YMnG4svOzqZUqfh3YQpK4JxzOOc46qj4zgU3efJkOnfuHFgCFXT9IskoEX0mzUIpIlHdeOON3HjjjUGHISIh9MEHH1CmTBmuueYaAFJSUvjnP//J2LFjc3+JXrt2La1bt6ZGjRrcc889AOzYsYNOnTrRoEED6tatS0ZGBgCLFi2iVatWNG7cmA4dOvDDDz8A3sNxBw0aRHp6Ovfffz9nnnkm+/fvzz1W1apV2bt3L6tWraJjx440btyYli1b8tVXXwHw7bffcu6551KvXj2GDRsW9byGDh3Krl27SEtL48orryQrK4tatWpx9dVXU7duXdauXUu5cuVyy0+cODF3+vB+/foxcOBAzjvvPM466ywmTpyYW+7vf/879erVo0GDBgwdOhSAZ599liZNmtCgQQN69OjBzp07mTt3LlOnTmXw4MGkpaWxatWqYp/b/fffT82aNfnDH/7AihUrctfHWn9+5UTkQInoM+kKnBzRqg19J+gQkkKvXr2CDkFEQurLL7+kcePGB6w7/vjjOeOMM1i5ciUA8+fPZ+nSpZQtW5YmTZrQqVMn1qxZw2mnncY773jfw1u3bmXv3r3ccsstTJkyhcqVK5ORkcFdd93F2LFjAdizZw8LFy4EYPHixXz44Ye0adOGt99+mw4dOpCamkr//v156qmnqFGjBp999hk33XQTH3zwAbfeeis33ngjV199NWPGjIl6XqNGjeLxxx8nMzMTgKysLL755hv+/e9/07x586j7//DDD8yZM4evvvqKLl260LNnT959912mTJnCZ599RtmyZdm8eTMA3bt35/rrrwdg2LBhPP/889xyyy106dKFzp0707NnTwDatm1b5HNbtGgREyZMIDMzk+zsbBo1apT77xZr/SeccEK+5UTkfxLRZ1ICJyJRbd26FYDy5csHHIlI+JhZR+BRIAV4zjk3Ks/2G4CbgX3AdqC/c26Zv+0O4I/+toHOufcTGXuitG/fnooVKwJesjBnzhwuvvhibrvtNoYMGULnzp1p2bIlS5cuZenSpbnTcO/bt49TTz019ziRHaNevXqRkZFBmzZtmDBhAjfddBPbt29n7ty5XHbZZbnlfvvtNwA++eQT3njjDQCuuuoqhgwZcsjnceaZZ8aUvAF
ceumlHHXUUdSuXZuffvoJgBkzZnDNNdfk3htz4oknArB06VKGDRvGL7/8wvbt2+nQocNBxyvuuX388cd069Ytt+4uXbrkboul/kMpJ3IkS0SfKaYELobG6c/AdUA2sAG41jm3xt+2D1jiF/3OOdcFEUkqXbt2BXQPnEheZpYCjAHaA+uABWY2NSdB873inHvKL98FeAToaGa1gd5AHeA0YIaZ1XTO7UvoSRRT7dq1DxgiCPDrr7/y3Xff8fvf/57Fixcf9AggM6NmzZosXryYadOmMWzYMNq2bUu3bt2oU6cOn376ab51HXvssbnvu3Tpwp133snmzZtZtGgRF1xwATt27OCEE07IvWqWV3EfRRRZf97j7d69+4BtRx99dO77aM8h7devH5MnT6ZBgwaMGzcu3+/a/fv3x+3cYqn/UMqJHMkS0WeKeg9cRON0EVAbuNxvdCL9F0h3ztUHJgL/iNi2yzmX5r+UvIkkoYEDBzJw4MCgwxAJo6bASufcaufcHmAC0DWygHPu14jFY4Gc3nxXYIJz7jfn3LfASv94SaVt27bs3LmTF198EfCumt12223069cv92rP9OnT2bx5M7t27WLy5Mm0aNGC77//nrJly9KnTx8GDx7M4sWLqVWrFhs2bMhN4Pbu3cuXX36Zb73lypWjSZMm3HrrrXTu3JmUlBSOP/54qlevzuuvvw54idPnn38OQIsWLZgwYQIA48ePP+BYZ599dr51pKamsnfv3gLP/eSTT2b58uXs37+fSZMmRf2s2rdvzwsvvJB771jOEMpt27Zx6qmnsnfv3gNiO+6449i2bRtAkc8tx/nnn8/kyZPZtWsX27Zt46233srdFkv9hZUTKWmjZ359yK9qQ9/JfQUpEX2mWCYxiaVxmuWcy7mTdR6gh22JHEa6d+9O9+7dgw5DJIxOB9ZGLK/z1x3AzG42s1V4P3AOPJR9w87MmDRpEq+//jo1atSgZs2alClThgceeCC3TNOmTenRowf169enR48epKens2TJEpo2bUpaWhr33HMPw4YNo3Tp0kycOJEhQ4bQoEED0tLSmDt3boF19+rVi5dffvmAoZXjx4/n+eefp0GDBtSpU4cpU6YA8OijjzJmzBjq1avH+vXrc8tv3LixwCtk/fv3p379+lx55ZX5bh81ahSdO3fmvPPOO2CoZ0E6duxIly5dSE9PJy0tjYceegiAe++9l2bNmtGiRYsDksnevXvz4IMP0rBhQ1atWnXI5xapUaNG9OrViwYNGnDRRRfRpEmT3G2x1l9QORH5n0T0mSzaZX0z6wl0dM5d5y9fBTRzzg0ooPzjwI/Oufv85WwgE2945Sjn3OTC6ktPT3c5NyiLxFvQv9KUtKxRneJy3I0bNwJQqVKluBxfjlxmtsg5lx50HEVVhDbyCqCDc66v317Oc8697G97HnjXOTcxzz79gf4AZ5xxRuM1a9YccMzly5dzzjnnlPCZHTnefvttVq9erVEGAdLfcBFEeVzI6JlfJyiQ8Bid3TP3fbz6Q7EoqT5TYe1jiU5iYmZ9gHSgVcTqM51z683sLOADM1vinFuVZ7/IxqkkQxKREpAzA5nudxA5yHqgasRyFX9dQSYATx7Kvs65Z4BnwPuRszjBysE6d+4cdAgichhJRJ8plgQupgbGzNoBdwGtnHO/5ax3zq33/7vazGYDDYEDEjg1TiLhdttttwUdgkhYLQBqmFl1vLaxN3BFZAEzq+Gc+8Zf7ATkvJ8KvGJmj+BNYlIDmJ+QqEVEJC4S0WeKJYGLpXFqCDyNN4zk54j1FYCdzrnfzKwS0IIDJzgRkSRwySWXBB2CSCg557LNbADwPt5MzWOdc1+a2UhgoXNuKjDA/5FzL7AF6Ovv+6WZvQYsw7vN4OZkm4FSREQOlIg+U9QELsbG6UGgHPC6P41tzuMCzgGeNrP9eBOmjMoztbKIJIEff/wRgFNOOSXgSETCxzk3DZiWZ93wiPe3FrLv/cD9JRB
DsafIFwlCtLkYRJJNIvpMMd0DF0Pj1K6A/eYC9YoToIgEr3fv3oDugRMJozJlyrBp0yYqVqyoJE6SinOOTZs2UaZMmaBDESkxiegzlegkJiJyeBo6dGjQIYhIAapUqcK6devYsGFD0KGIHLIyZcpQpYqePiWHj0T0mZTAiUhUHTt2DDoEESlAamoq1atXDzoMESkhsTziaFCpI+8xAckiEX2mWB7kLSJHuLVr17J27droBUVERESOYInoM+kKnIhEddVVVwG6B05ERESkMInoMymBE5Gohg0bFnQIIiIiIqGXiD6TEjgRiapdu3wnmhURERGRCInoM+keOBGJavXq1axevTroMERERERCLRF9Jl2BE5Gorr32WkD3wImIiIgUJhF9JiVwIhLVPffcE3QIIiIiIqGXiD6TEjgRiapVq1ZBhyAiIiISeonoM+keOBGJasWKFaxYsSLoMERERERCLRF9Jl2BE5Go/vSnPwG6B05ERESkMInoMymBE5GoHnjggaBDEBEREQm9RPSZlMCJSFTnnXde0CGIiIiIhF4i+ky6B05Eolq6dClLly4NOgwRERGRUEtEn0lX4EQkqgEDBgC6B05ERESkMInoMymBE5GoHnzwwaBDEBEREQm9RPSZlMCJSFRNmjQJOgQRERGR0EtEn0n3wIlIVJmZmWRmZgYdhoiIiEioJaLPpCtwIhLVoEGDAN0DJyIiIlKYRPSZlMCJSFSjR48OOgQRERGR0EtEn0kJnIhElZaWFnQIIiIiIqGXiD6T7oETkagWLFjAggULgg5DREREJNQS0WfSFTgRiWrw4MGA7oETERERKUwi+kxK4EQkqscffzzoEERERERCLxF9JiVwIhJV3bp1gw5BREREJPQS0WfSPXAiEtXcuXOZO3du0GGIiIiIhFoi+ky6AiciUd15552A7oETERERKUwi+kxK4EQkqqeffjroEERCycw6Ao8CKcBzzrlRebb/GbgOyAY2ANc659b42/YBS/yi3znnuiQscBERiYtE9JmUwIlIVLVq1Qo6BJHQMbMUYAzQHlgHLDCzqc65ZRHF/gukO+d2mtmNwD+AXv62Xc65tETGLCIi8ZWIPpPugRORqD788EM+/PDDoMMQCZumwErn3Grn3B5gAtA1soBzbpZzbqe/OA+okuAYRUQkgRLRZ9IVOBGJasSIEYDugRPJ43RgbcTyOqBZIeX/CLwbsVzGzBbiDa8c5ZybnN9OZtYf6A9wxhlnFCdeERGJs0T0mZTAiRxGqg19p9jHyBrV6aB1Y8eOLfZxRY5kZtYHSAdaRaw+0zm33szOAj4wsyXOuVV593XOPQM8A5Cenu4SErCIiBRJIvpMSuBEJKqzzjor6BBEwmg9UDViuYq/7gBm1g64C2jlnPstZ71zbr3/39VmNhtoCByUwImISPJIRJ9JCZyIRDVjxgwA2rVrF3AkIqGyAKhhZtXxErfewBWRBcysIfA00NE593PE+grATufcb2ZWCWiBN8GJiIgUU3FHJOU3GilWiegzKYETkajuu+8+QAmcSCTnXLaZDQDex3uMwFjn3JdmNhJY6JybCjwIlANeNzP43+MCzgGeNrP9eBOKjcoze6WIiCShRPSZYkrgivmcm77AML/ofc65f5dQ7CKSIC+99FLQIYiEknNuGjAtz7rhEe/zbcGdc3OBevGNTkSCUBL3o0vySkSfKWoCV5zn3JjZicAIvBu3HbDI33dLSZ+IiMRP1apVoxcSEREROcIlos8Uy3PgivOcmw7AdOfcZj9pmw50LJnQRSRR3nvvPd57772gwxAREREJtUT0mWIZQlmc59zkt+/phxKgiARv1Chv1HTHjvr9RURERKQgiegzlegkJgU85yaW/fSQUpEQmzBhQtAhiIiIiIReIvpMsSRwxXnOzXqgdZ59Z+fdVw8pFQm3U045JegQREREREIvEX2mWBK4Ij/nBm9q5Qf8590AXAjcUeyoRSSh3nrrLQAuueSSgCM
RERERia/izCS6c+VnPN+3SVz7TFETuOI858Y5t9nM7sVLAgFGOuc2x+VMRCRuHn74YUAJnIiIiEhhfp0/iYc3fhRsAgdFf86Nv20sMLaoAYpI8CZOnBh0CCIiIiKhV/nSO5g4/MK41lGik5iIyOGpUqVKQYcgIiIiEnopZcvHvd8Uy3PgROQI9+abb/Lmm28GHYaIiIhIqO1cMTfufSYlcCIS1WOPPcZjjz0WdBgiIiIiofbroqlx7zNpCKUkreLMECSHZsqUKUGHICIiIlKgQaXCcb/+rv+rz01/fSKudSiBE5GoypcvH3QIIiIiIqF3TJnSce83aQiliESVkZFBRkZG0GGIiIiIhNp/l30b9z6TEjgRierJJ5/kySefDDoMERERkVD7ZNFXce8zaQiliEQ1bdq06IVEREREjnD9e7dnwIin41qHEjgRiaps2bJBhyAiIiISeqVTS8W936QhlCIS1csvv8zLL78cdBgiIiIiobZwyaq495mUwIlIVM899xzPPfdc0GGIiIiIhNq8zK/j3mfSEEoRiWr69OlBhyAiIiISejde0YGBI5XAiUjAUlNTgw5BREQkeLP+FrXIoFJfJyAQCauUlKPi3m/SEEoRiWrcuHGMGzcu6DBEREREQm3+59/Evc+kBE5EolICJyIiIhLd/C9Wxr3PpCGUIhLV7Nmzgw5BREREJPQGXHURg+57Ia516AqciIiIiIhIklACJyJRPfvsszz77LNBhyEiIiISap/+9+u495mUwIlIVBkZGWRkZAQdhoiIiEio/XfZt3HvM+keOBGJasaMGUGHICIiIhJ6N13ZQffAiYiIhJWZdTSzFWa20syG5rP9z2a2zMy+MLOZZnZmxLa+ZvaN/+qb2MhFRCRZKYETkaieeOIJnnjiiaDDEAkVM0sBxgAXAbWBy82sdp5i/wXSnXP1gYnAP/x9TwRGAM2ApsAIM6uQqNhFRCQ+5iz8Ku59JiVwIhLVW2+9xVtvvRV0GCJh0xRY6Zxb7ZzbA0wAukYWcM7Ncs7t9BfnAVX89x2A6c65zc65LcB0oGOC4hYRkTj58pu1ce8z6R44EYnq3XffDToEkTA6HVgbsbwO74paQf4I5PzPlN++p5dodCIiknB/urx93O+BUwInIiISZ2bWB0gHWhVh3/5Af4AzzjijhCMTEZFkoyGUIhLVo48+yqOPPhp0GCJhsx6oGrFcxV93ADNrB9wFdHHO/XYo+wI4555xzqU759IrV65cIoGLiEh8fDh/Wdz7TErgRCSqmTNnMnPmzKDDEAmbBUANM6tuZqWB3sDUyAJm1hB4Gi95+zli0/vAhWZWwZ+85EJ/nYiIJLFvsn6Ie59JQyhFJKqpU6dGLyRyhHHOZZvZALzEKwUY65z70sxGAgudc1OBB4FywOtmBvCdc66Lc26zmd2LlwQCjHTObQ7gNEREpARd939tdQ+ciIhIWDnnpgHT8qwbHvG+XSH7jgXGxi86ESmKakPfKXDboFJfJzASkfxpCKWIRPXQQw/x0EMPBR2GiIiISKjNmrc07n0mXYETkQPk98vjhkmTAHh84zkxHydrVKcSi0lEREQkGWSt20CpTz+Nax1K4EQkqsrd7gw6BBEREZHQu6Znm7jfA6chlCIiIiIiIklCCZyIRLV13utsnfd60GGIiIiIhNqMuV8watSouNahIZQiEtWen1YHHYKIiIhI6H3/02YyMzPjWocSOBGJqnLXIUGHICIiIhJ6V3drHY574Myso5mtMLOVZjY0n+3nm9liM8s2s555tu0zs0z/pacBi4iIiIiIFFHUBM7MUoAxwEVAbeByM6udp9h3QD/glXwOscs5l+a/uhQzXhEJwC+fvMovn7wadBgiIiIiofafjz/n3nvvjWsdsQyhbAqsdM6tBjCzCUBXYFlOAedclr9tfxxiFJGAZW9eH3QIIiIiIqH38+atrFixIq51xJLAnQ6sjVheBzQ7hDrKmNlCIBsY5ZybfAj7ikg
IVLrkL0GHICIiIhJ6fbqeH/d74BIxicmZzrn1ZnYW8IGZLXHOrYosYGb9gf4AZ5xxRgJCEhERERERST6xTGKyHqgasVzFXxcT59x6/7+rgdlAw3zKPOOcS3fOpVeuXDnWQ4tIgvzy8cv88vHLQYchIiIiEmrvfvhfhg8fHtc6YkngFgA1zKy6mZUGegMxzSZpZhXM7Gj/fSWgBRH3zolIcsj+dSPZv24MOgwRERGRUPvl1x2sXbs2esFiiDqE0jmXbWYDgPeBFGCsc+5LMxsJLHTOTTWzJsAkoAJwiZnd45yrA5wDPO1PbnIU3j1wSuBEkkylToOCDkFEREQk9C6/5A/huAfOOTcNmJZn3fCI9wvwhlbm3W8uUK+YMYqIiIiIiAgxPshbRI5sWz4cx5YPxwUdhoiIiEiovT1rEXfccUdc60jELJQikuT279oWdAgiIiIiobdj529s2rQprnUogZNAVBv6TtAhyCGo2PGWoEMQERERCb1enc5j0H3PxLUODaEUERERERFJEkrgRCSqLR88z5YPng86DBEREZFQmzJjAX/5y1/iWoeGUIpIVPuz9wQdgoiIiEjo7c3ex65du+JahxI4EYmq4oU3Bh2CiIiISOj17NicQfeNiWsdGkIpIiIiIiKSJHQFTkSi2jzDm03pxHb9A45ERESkcJrpWoI06T+fkbV9EKNHj45bHboCJyIiIiIikiR0BU5EotKVNxEREZHoul3YjEH3jY5rHboCJyIiIiIikiSUwIlIVJv+8ySb/vNk0GGIiIiIhNrE9+Zx8803x7UOJXAiEtVRpUpzVKnSQYchEjpm1tHMVpjZSjMbms/2881ssZllm1nPPNv2mVmm/5qauKhFRCReUkulcMwxx8S1Dt0DJyJRVbjgj0GHIBI6ZpYCjAHaA+uABWY21Tm3LKLYd0A/4C/5HGKXcy4t3nGKiEjidG3XhEH3PRTXOpTAiYiIFE1TYKVzbjWAmU0AugK5CZxzLsvftj+IAEVE5PCjIZQiEtWm9/7Fpvf+FXQYImFzOrA2Ynmdvy5WZcxsoZnNM7NLCypkZv39cgs3bNhQxFBFRCQRMt6ZS//+8Z29W1fgRCSqo445LugQRA5HZzrn1pvZWcAHZrbEObcqbyHn3DPAMwDp6eku0UGKiEjsji17NBUrVoxrHUrgRCSqCq36BR2CSBitB6pGLFfx18XEObfe/+9qM5sNNAQOSuBERCR5dG7TmEH3/S2udWgIpYiISNEsAGqYWXUzKw30BmKaTdLMKpjZ0f77SkALIu6dExERKYgSOBGJauM7o9n4zuigwxAJFedcNjAAeB9YDrzmnPvSzEaaWRcAM2tiZuuAy4CnzexLf/dzgIVm9jkwCxiVZ/ZKERFJQq++NYdrrrkmrnVoCKWIRFXq+EpBhyASSs65acC0POuGR7xfgDe0Mu9+c4F6cQ9QREQS6oTjj6Vq1arRCxaDEjgRieqEln2CDkFEREQk9C5q1ZBBI0fGtQ4lcCIiIiISXrMObUKIQaW+jlMgIuGge+BEJKqNbz3ExrceCjoMERERkVB7ecpH9OkT35FLugInIlGVOvFQnk0sIiIicmQ66cTy1KpVK651KIETkahOaHF50CGIiIiIhN6FLRsw6K9/jWsdGkIpIiIiIiKSJJTAiUhUG6b8nQ1T/h50GCIiIiKh9uKk2fTu3TuudWgIpYhEVfrks4IOQURERCT0Tjv5RNLS0uJahxI4EYmqfPPLgg5BREREJPTanVefQUOHxrUODaEUERERERFJEkrgRCSqDZMeYMOkB4IOQ0RERCTUXpg4ix49esS1Dg2hFJGoSp92dtAhiIiIiIRetSqVOffcc+NahxI4EYmqfLPuQYcgIiIiEnptmtdl0F/+Etc6NIRSREREREQkScSUwJlZRzNbYWYrzeygaVXM7HwzW2xm2WbWM8+2vmb2jf/qW1KBi0ji/PzGSH5+Y2TQYYiIiIiE2nOvzaRLly5xrSPqEEo
zSwHGAO2BdcACM5vqnFsWUew7oB/wlzz7ngiMANIBByzy991SMuGLSCKUObNB0CGIiIiIhF6NaqfSum3buNYRyz1wTYGVzrnVAGY2AegK5CZwzrksf9v+PPt2AKY75zb726cDHYFXix25iCTM8eldgw5BREREJPRaNa3NrbfeGtc6YhlCeTqwNmJ5nb8uFjHta2b9zWyhmS3csGFDjIcWERERERE5soRiEhPn3DPOuXTnXHrlypWDDkdE8vjptRH89NqIoMMQERERCbWnX53ORRddFNc6YhlCuR6oGrFcxV8Xi/VA6zz7zo5xXxEJibK/bxp0CCIiIiKhV6dGVS645JK41hFLArcAqGFm1fESst7AFTEe/33gATOr4C9fCNxxyFGKSKCOa9Qp6BBEROQIUW3oOwcsDyr1dUCRiBy6P6SfzU033RTXOqIOoXTOZQMD8JKx5cBrzrkvzWykmXUBMLMmZrYOuAx42sy+9PfdDNyLlwQuAEbmTGgiIiIiIiIihyaWK3A456YB0/KsGx7xfgHe8Mj89h0LjC1GjCISsJ8m3AXAyb3vDzgSERERkfB6Yvz7vD2vHTNmzIhbHTElcCJyZCt7dsugQxAREREJvYa1q9Pu0l5xrUMJnIhEdVxax6BDEBEREQm9cxvW5Prrr49rHaF4jICIiIiIiIhEpwRORKL68ZWh/PjK0KDDEBEREQm1x196l9atW8e1Dg2hFJGoytVrd8j75J0GuiiyRunxBSIiIpI8mtb/PRf26BfXOnQFTkSiKlevXZGSOJHDnZl1NLMVZrbSzA66TG1m55vZYjPLNrOeebb1NbNv/FffxEUtIiLx0rRBDfr16xfXOpTAiUhUbl82bl920GGIhIqZpQBjgIuA2sDlZlY7T7HvgH7AK3n2PREYATQDmgIjzKxCvGMWEZH42rdvP3v37o1rHUrgRCSqnzKG8VPGsKDDEAmbpsBK59xq59weYALQNbKAcy7LOfcFsD/Pvh2A6c65zc65LcB0QNO9iogkuSdfeZ/27dvHtQ7dAyciUZVr0CHoEETC6HRgbcTyOrwrakXd9/QSiktERALSPK0mHS+7Lq51KIETkajK1WkTdAgiRywz6w/0BzjjjDMCjkZERAqTXu939OnTJ651aAiliES1f+9u9u/dHXQYImGzHqgasVzFX1ei+zrnnnHOpTvn0itXrlykQEVEJDH27M1m586dca1DCZyIRPXz63fz8+t3Bx2GSNgsAGqYWXUzKw30BqbGuO/7wIVmVsGfvORCf52IiCSxZyZM5+KLL45rHRpCKSJRHdcwvl9EIsnIOZdtZgPwEq8UYKxz7kszGwksdM5NNbMmwCSgAnCJmd3jnKvjnNtsZvfiJYEAI51zmwM5ERERKTEtGp/Nxb1ujGsdSuBEJKpjzzk/6BBEQsk5Nw2Ylmfd8Ij3C/CGR+a371hgbFwDFBGRhGpYuzq9evWKax1K4OSQVBv6TtAhSAD2/7YDgKOOPjbgSERERETCa9fuPWzdupXy5cvHrQ7dAyciUf38xr38/Ma9QYchIiIiEmrPvz6Trl27Ri9YDLoCJyJRHd+4S9AhiIiIiITe+U1q0/mKAXGtQwmciERVttZ5QYcgIiIiEnr1zz6T7t27x7UODaEUkaj27dzKvp1bgw5DREREJNS279zNxo0b41qHEjgRiWrD5L+xYfLfgg5DREREJNTGvTGLnj17xrUODaEUkaiOb9ot6BBEREREQq91s7p0verWuNahBE5Eoir7+2ZBhyAiIiISenVrVuWSSy6Jax0aQikiUe3bvoV927cEHYaIiIhIqP26fSc//vhjXOtQAiciUW2Y+nc2TP170GGIiIiIhNqLkz6kd+/eca1DQyhFJKryzS8LOgQRERGR0Gt7Xj269b0trnUogRORqI45q3HQIYiIiIiE3jm/q0LHjh3jWoeGUIpIVNm/biD71w1BhyEiIiISalt+3cHatWvjWocSOBGJauPbD7Px7Ye
DDkNEREQk1MZP+YirrroqrnVoCKWIRFX+vPjejCsiIiE0629F2m30zK+LVe0g9U4libX/QwN6XDM4rnXofxERieqYamlBhyAiIiISerWqn0a7du3iWoeGUIpIVHt/+ZG9v8T3mSYiIiIiyW7jlm2sXr06rnUogRORqDZNG82maaODDkNEREQk1Ca8PYdrr702rnVoCKWIRHXCH64MOgQRERGR0Ot4fkMuu25oXOtQAiciUZU5o17QIYiIiIiE3u/PPIVWrVrFtQ4NoRSRqPZuWsfeTeuCDkNEREQk1H7etJUVK1bEtY6YEjgz62hmK8xspZkddE3QzI42swx/+2dmVs1fX83MdplZpv96qoTjF5EE2PT+42x6//GgwxAREREJtdemzeVPf/pTXOuIOoTSzFKAMUB7YB2wwMymOueWRRT7I7DFOfd7M+sN/B3o5W9b5ZxLK9mwRSSRTji/b9AhiIiIiIRepzaN6dX/rrjWEcsVuKbASufcaufcHmAC0DVPma7Av/33E4G2ZmYlF6aIBKlMlXMoU+WcoMMQERERCbXqVU7ivPPOi2sdsSRwpwNrI5bX+evyLeOcywa2AhX9bdXN7L9m9qGZtSxmvCISgD0bstizISvoMERERERC7Yeft7B06dK41hHvSUx+AM5wzjUE/gy8YmbH5y1kZv3NbKGZLdywYUOcQxKRQ7V5+lNsnq5bWEVEREQK88b78xgwYEBc64jlMQLrgaoRy1X8dfmVWWdmpYDywCbnnAN+A3DOLTKzVUBNYGHkzs65Z4BnANLT010RzkNE4qhCm/g+kFJERETkcNClbRMuv3F4XOuI5QrcAqCGmVU3s9JAb2BqnjJTgZxZDnoCHzjnnJlV9idBwczOAmoAq0smdBFJlKNPrcnRp9YMOgwRERGRUDvjtEo0adIkrnVETeD8e9oGAO8Dy4HXnHNfmtlIM+viF3seqGhmK/GGSuY8auB84Aszy8Sb3OQG59zmEj4HEYmzPT+tZs9P+u1FREREpDDrf9xEZmZmXOuIZQglzrlpwLQ864ZHvN8NXJbPfm8AbxQzRhEJ2OaZzwBwyhWjAo5EJFzMrCPwKJACPOecG5Vn+9HAi0BjYBPQyzmX5T8vdTmQ87TXec65GxIWuIiIxMWk6fNZsHIQs2fPjlsdMSVwInJkO7Ft/0DqrTb0nWIfI2tUpxKIRORgek6qiIjk1a19U64cMDKudcR7FkoROQyUPvksSp98VtBhiISNnpMqIiIHOP2UiqSlpcW1DiVwIhLVbz98zW8/fB10GCJhk5DnpOpROyIiyeO77zeyYMGCuNahBE5EotoyayxbZo0NOgyRw0lMz0kF71E7zrl051x65cqVExqkiIgcmqkzFzB48OC41qF74EQkqhPba24FkXzE/TmpIiKSXHp0aM5VA++Lax1K4I4gJTEhhByZSleuFnQIImGU+5xUvEStN3BFnjI5z0n9lDzPSQU2O+f26TmpEjY5/YVBpTR0XuRQnXpSBerWrRvXOjSEUkSi2r1uObvXLQ86DJFQ0XNSRUQkr2/X/czcuXPjWoeuwIlIVL985E2ip+fAiRxIz0kVEZFI78xaxOdZd+o5cCISrIodBgQdgoiIiEjo/d/F59F30N/iWocSOBGJKrVilaBDEBEREQm9kyqWp1atWnGtQ/fAiUhUu79bwu7vlgQdhoiIiEiorVzzIx9++GFc61ACJyJR/TJnPL/MGR90GCIiIiKh9t5H/2XEiBFxrUNDKEUkqooXDwo6BBEREZHQ6935D1x72z/iWocSOBGJKvWEU4IOQURERCT0KlU4jrPOOiuudWgIpYhEtSsrk11ZmUGHISIiIhJqK779nhkzZsS1Dl2BE5Gots6dAMAx1dKCDUREREQkxKbP+Zzl6++jXbt2catDCZyIRFWp821BhyAicmSYVfznR42e+XWx9h+k3qFIkV3Z9XyuG/xwXOvQ/6IiElWp4ysHHYKIiIhI6FU4/liqVq0
a1zp0D5yIRLVr9SJ2rV4UdBgiIiIiobZ81Tree++9uNahK3AiEtXWea8DcMxZjQOORERERCS8Zs5dwoofRtGxY8e41aEETkSiqtxlSNAhiIiIiITe1d1a0X/oo3GtQwmciESVUq5C0CEUWbWh7xT7GFmjOpVAJCIiInK4O75cWU45Jb7Pz9U9cCIS1c6Vn7Fz5WdBhyEiIiISaku/Xstbb70V1zqUwIlIVL/On8Sv8ycFHYaIiIhIqM3+bCkPP6zHCAglMwxMpKgqX3pH0CGIiIiIhF6/Hm244c5/xbUOJXAiElVK2fJBhyAicsQo7oO4RSQ45cqWoVKlSnGtQ0MoRSSqnSvmsnPF3KDDEBEREQm1L75aw5tvvhnXOpTAiUhUvy6ayq+LpgYdhoiIiEiofbRgGY899lhc69AQShGJ6qQefw06BBEREZHQ++Nlbbnpr0/EtQ4lcCIS1VFHHxt0CCIiIiKhd0yZ0pQvH9+5AzSEUkSi2rH8I3Ys/yjoMERERERC7b/LviUjIyOudegKnIhEte2/0wA49pzzA45ERCTcivvYn0GlNAOlSDL7ZNFXrNn8JL169YpbHUrgRCSqky67O+gQREREREKvf+/2DBjxdFzrUAInIlEdlVom6BACVdxf1HNkjepUIscRERGRcCqdWoqyZcvGtQ4lcHFWUh0/kSBt/3IWAOXqtAk4EhEREZHwWrhkFS+//DJ9+vSJWx1K4EQkqu2fvw8ogSuukvhBR1fxREREwmte5tes2/pcXBO4mGahNLOOZrbCzFaa2dB8th9tZhn+9s/MrFrEtjv89SvMrEMJxi4iCXJyr/s4udd9QYchEkpqI0VEJMeNV3Rg+vTpca0j6hU4M0sBxgDtgXXAAjOb6pxbFlHsj8AW59zvzaw38Hegl5nVBnoDdYDTgBlmVtM5t6+kT0RE4sdSdLFeJD9qIw8Ts/6W+3b0zOLNAjlIX5ciR7SUlKNITU2Nax2xfM00BVY651YDmNkEoCsQ2Th1Be72308EHjcz89dPcM79BnxrZiv9431aMuHHl+5fE/FsXzIDgHL12gUciUjoHLFtZFgVpe3W1P0iUlLmf/4N48aNo1+/fnGrI5YE7nRgbcTyOqBZQWWcc9lmthWo6K+fl2ff04sc7SFQ8iVScpTAhUeYvtt0Px6QjG1kxNWmeDiUK1ijs3vGMRIRkcSb/8VKvt8WfAIXd2bWH+jvL243sxWHeIhKwMaSjSppHMnnDkf2+Sf83Nf8vXMiqyvMkfzvDiE5f/t7iRzmzBI5ymGsBNrIklaCf3/jSuYwh+ag+P9fEFEUXSj+/y8GxR8sxZ8Aq777CW+gxUEOJf4C28dYErj1QNWI5Sr+uvzKrDOzUkB5YFOM++KcewZ4JoZY8mVmC51z6UXdP5kdyecOR/b569yPzHMHnX/IhL6NLGnJ/ven+IOl+IOl+INVUvHHMgvlAqCGmVU3s9J4N1xPzVNmKtDXf98T+MA55/z1vf0ZuKoDNYD5xQ1aREQkJNRGiohIQkW9AueP1x8AvA+kAGOdc1+a2UhgoXNuKvA88JJ/A/ZmvAYMv9xreDdzZwM3a3YtERE5XKiNFBGRRIvpHjjn3DRgWp51wyPe7wYuK2Df+4H7ixFjLEIztCQAR/K5w5F9/jr3I9eRfv6hkgRtZElL9r8/xR8sxR8sxR+sEonfvFEcIiIiIiIiEnax3AMnIiIiIiIiIXDYJHBmdouZfWVmX5rZP4KOJwhmdpuZOTOrFHQsiWJmD/r/7l+Y2SQzOyHomOLNzDqa2QozW2lmQ4OOJ5HMrKqZzTKzZf7/67cGHVOimVmKmf3XzN4OOhY5siVzu2tmd5vZejPL9F8XBx1TUSRru29m9/rtdqaZ/cfMTgs6pkOR7H0PM7vM//92v5klzYyOydz/MbOxZvazmS0tieMdFgmcmbUBugINnHN1gIcCDinhzKwqcCHwXdCxJNh0oK5zrj7wNXBHwPHElZmlAGOAi4D
awOVmVjvYqBIqG7jNOVcbaA7cfISdP8CtwPKgg5Aj22HS7v7TOZfmv6ZFLx4uSd7uP+icq++cSwPeBoZHKR82yd73WAp0Bz4KOpBYHQb9n3FAx5I62GGRwAE3AqOcc78BOOd+DjieIPwTuB04om5qdM79xzmX7S/Ow3uO0uGsKbDSObfaObcHmIDXiToiOOd+cM4t9t9vw0tkTg82qsQxsypAJ+C5oGORI57a3eAlbbvvnPs1YvFYkuwckr3v4Zxb7pxbEXQchyip+z/OuY/wZiEuEYdLAlcTaGlmn5nZh2bWJOiAEsnMugLrnXOfBx1LwK4F3g06iDg7HVgbsbyOIyiBiWRm1YCGwGcBh5JIo/E6bPsDjkPkcGh3B/hD4MaaWYWggzkUh0O7b2b3m9la4EqS7wpcpCOh7xEG6v9EiOkxAmFgZjOAU/LZdBfeeZyIN6SqCfCamZ3lDqMpNqOc/514wygOS4Wdu3Nuil/mLrzhdeMTGZsEw8zKAW8Ag/L8knvYMrPOwM/OuUVm1jrgcOQIkOztbpT4nwTuxbvycy/wMF5HPDSSvd2P1nY75+4C7jKzO4ABwIiEBhhFsvc9YolfklfSJHDOuXYFbTOzG4E3/YZjvpntByoBGxIVX7wVdP5mVg+oDnxuZuBdxl9sZk2dcz8mMMS4KezfHsDM+gGdgbZh6jzEyXqgasRyFX/dEcPMUvGSt/HOuTeDjieBWgBd/MkWygDHm9nLzrk+Acclh6lkb3ejtR05zOxZvPuwQiXZ2/1YP3+85GcaIUvgkr3vcQiff7I44vs/kQ6XIZSTgTYAZlYTKA1sDDKgRHHOLXHOneScq+acq4Z3SblRmL7E48nMOuINKevinNsZdDwJsACoYWbVzaw00BuYGnBMCWNeb+V5YLlz7pGg40kk59wdzrkq/v/nvYEPlLxJgCaTxO2umZ0asdgNb1KHpHA4tPtmViNisSvwVVCxFMUR2PcIgyO6/5NX0lyBi2IsMNafmnMP0DeMv4ZIXDwOHA1M93+JnOecuyHYkOLHOZdtZgOA94EUYKxz7suAw0qkFsBVwBIzy/TX3ZmMM8iJJLlkb3f/YWZpeEMos4A/BRrNkWeUmdXCu593DZBs7XZS9z3MrBvwL6Ay8I6ZZTrnOgQcVqGSvf9jZq8CrYFKZrYOGOGce77Ix0uu71sREREREZEj1+EyhFJEREREROSwpwROREREREQkSSiBExERERERSRJK4ERERERERJKEEjgREREREZEkoQROREREREQkSSiBExERERERSRJK4ERERERERJLE/wfY/OdefkXyUwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f, axes = plt.subplots(1, 2, figsize=(15, 5))\n", "\n", "axes[0].hist(\n", " samples_thinned.copy(), label=\"Untruncated posterior\", bins=20, density=True\n", ")\n", "axes[0].axvline(true_high, linestyle=\":\", color=\"k\", label=\"Truncation point\")\n", "axes[0].set_title(\"Untruncated posterior\")\n", "axes[0].legend()\n", "\n", "axes[1].hist(\n", " samples_thinned[samples_thinned < true_high].copy(),\n", " label=\"Tail of untruncated posterior\",\n", " bins=20,\n", " density=True,\n", ")\n", "axes[1].hist(true_x.copy(), label=\"Observed, truncated data\", density=True, alpha=0.5)\n", "axes[1].axvline(true_high, linestyle=\":\", color=\"k\", label=\"Truncation point\")\n", "axes[1].set_title(\"Comparison to observed data\")\n", "axes[1].legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "8ed9169a", "metadata": {}, "source": [ "#### 5.3 Example: Left-truncated Poisson \n", "\n", "As a final example, we now implement a left-truncated Poisson distribution.\n", "Note that a right-truncated Poisson could be reformulated as a particular\n", "case of a categorical distribution, so we focus on the less trivial case.\n", "\n", "**Class attributes**\n", "\n", "For a truncated Poisson we need two parameters, the `rate` of the original Poisson\n", "distribution and a `low` parameter to indicate the truncation point.\n", "As this is a discrete distribution, we need to clarify whether or not the truncation point is included\n", "in the support. In this tutorial, we'll take the convention that the truncation point `low`\n", "_is_ part of the support.\n", "\n", "The `low` parameter has to be given a 'non-negative integer' constraint. As it is a discrete parameter, it will not be possible to do inference for this parameter using [NUTS](https://num.pyro.ai/en/stable/mcmc.html#nuts). This is likely not a problem since the truncation point is often known in advance. 
However, if we really must infer the `low` parameter, it is possible to do so with [DiscreteHMCGibbs](https://num.pyro.ai/en/stable/mcmc.html#discretehmcgibbs), though one is limited to priors with enumerable support.\n", "\n", "As with the truncated normal, the support of this distribution is defined as a property rather than a class attribute, because it depends on the specific value of the `low` parameter.\n", "```python\n", "class LeftTruncatedPoisson(Distribution):\n", "    arg_constraints = {\n", "        \"low\": constraints.nonnegative_integer,\n", "        \"rate\": constraints.positive,\n", "    }\n", "\n", "    # ...\n", "    @constraints.dependent_property(is_discrete=True)\n", "    def support(self):\n", "        return constraints.integer_greater_than(self.low - 1)\n", "```\n", "\n", "The `is_discrete` argument passed to the `dependent_property` decorator tells the inference algorithms which variables are discrete latent variables.\n", "\n", "**The** `__init__` **method**\n", "\n", "Here we just follow the same pattern as in the previous example.\n", "```python\n", "    # ...\n", "    def __init__(self, rate=1.0, low=0, validate_args=None):\n", "        batch_shape = lax.broadcast_shapes(\n", "            jnp.shape(low), jnp.shape(rate)\n", "        )\n", "        self.low, self.rate = promote_shapes(low, rate)\n", "        super().__init__(batch_shape, validate_args=validate_args)\n", "    # ...\n", "```\n", "\n", "**The** `log_prob` **method**\n", "\n", "The logic is very similar to the truncated normal case. 
But this time we are truncating on the left, so the correct normalization constant is the complementary cumulative distribution function of the original variable evaluated at $L - 1$ (where $L$ denotes the truncation point `low`):\n", "\n", "$$\n", "\\begin{align}\n", "M = \\sum_{n=L}^{\\infty} p_Y(n) = 1 - \\sum_{n=0}^{L - 1} p_Y(n) = 1 - F_Y(L - 1)\n", "\\end{align}\n", "$$\n", "\n", "For the code, we can rely on the `poisson` module that lives inside `jax.scipy.stats`.\n", "\n", "```python\n", "    # ...\n", "    def log_prob(self, value):\n", "        m = 1 - poisson.cdf(self.low - 1, self.rate)\n", "        log_p = poisson.logpmf(value, self.rate)\n", "        return jnp.where(value >= self.low, log_p - jnp.log(m), -jnp.inf)\n", "    # ...\n", "```\n", "\n", "**The** `sample` **method**\n", "\n", "Inverse-transform sampling also works for discrete distributions. The \"inverse\" cdf of a discrete distribution is defined as:\n", "\n", "$$\n", "\\begin{align}\n", "F^{-1}(u) = \\min\\left\\{n\\in \\mathbb{N} \\mid F(n) \\geq u\\right\\}\n", "\\end{align}\n", "$$\n", "\n", "Or, in plain English, $F^{-1}(u)$ is the smallest integer for which the cumulative distribution function is greater than or equal to $u$.\n", "However, at the time of writing, JAX does not provide $F^{-1}$ for the Poisson distribution, so we have to write our own. Fortunately, the discrete nature of the distribution makes it easy to implement a \"brute-force\" version that works in most cases: starting from the truncation point, we scan the integers in order, one by one, until the cumulative distribution function reaches or exceeds $u$. 
The implicit requirement is that we need a way to evaluate the cumulative distribution function of the truncated distribution, which we can derive as follows:\n", "\n", "$$\n", "\\begin{align}\n", "F_Z(z) &= \\sum_{n=0}^z p_Z(n)\\newline\n", "       &= \\frac{1}{M}\\sum_{n=L}^z p_Y(n)\\quad \\text{assuming $z \\geq L$}\\newline\n", "       &= \\frac{1}{M}\\left(\\sum_{n=0}^z p_Y(n) - \\sum_{n=0}^{L-1}p_Y(n)\\right)\\newline\n", "       &= \\frac{1}{M}\\left(F_Y(z) - F_Y(L-1)\\right)\n", "\\end{align}\n", "$$\n", "\n", "And, of course, $F_Z(z) = 0$ if $z < L$.\n", "(As in the previous example, $Y$ denotes the original, un-truncated variable and $Z$ denotes the truncated variable.)\n", "\n", "```python\n", "    # ...\n", "    def sample(self, key, sample_shape=()):\n", "        shape = sample_shape + self.batch_shape\n", "        minval = jnp.finfo(jnp.result_type(float)).tiny\n", "        u = random.uniform(key, shape, minval=minval)\n", "        return self.icdf(u)\n", "\n", "    def icdf(self, u):\n", "        def cond_fn(val):\n", "            n, cdf = val\n", "            return jnp.any(cdf < u)\n", "\n", "        def body_fn(val):\n", "            n, cdf = val\n", "            n_new = jnp.where(cdf < u, n + 1, n)\n", "            return n_new, self.cdf(n_new)\n", "\n", "        low = self.low * jnp.ones_like(u)\n", "        cdf = self.cdf(low)\n", "        n, _ = lax.while_loop(cond_fn, body_fn, (low, cdf))\n", "        return n.astype(jnp.result_type(int))\n", "\n", "    def cdf(self, value):\n", "        m = 1 - poisson.cdf(self.low - 1, self.rate)\n", "        f = poisson.cdf(value, self.rate) - poisson.cdf(self.low - 1, self.rate)\n", "        return jnp.where(value >= self.low, f / m, 0)\n", "```\n", "\n", "A few comments on the above implementation:\n", "* Even with double precision, the above code breaks down if `rate` is much smaller than `low`: because of floating-point limitations, `poisson.cdf(low - 1, rate)` evaluates to `1.0` (or very close to it). 
This makes it impossible to re-weight the distribution accurately because the normalization constant would be `0.0`.\n", "* The brute-force `icdf` is of course very slow, particularly when `rate` is high, since the number of loop iterations grows with the distance between `low` and the largest sampled value. If you need faster sampling, one option is to use a faster search algorithm. For example:\n", "\n", "```python\n", "def icdf_faster(self, u):\n", "    num_bins = 200  # Choose a reasonably large value\n", "    bins = jnp.arange(num_bins)\n", "    cdf = self.cdf(bins)\n", "    indices = jnp.searchsorted(cdf, u)\n", "    return bins[indices]\n", "```\n", "\n", "The obvious limitation here is that the number of bins has to be fixed a priori (JAX does not support dynamically sized arrays). Values larger than `num_bins - 1` can never be sampled: JAX clamps out-of-bounds indices, so they are silently mapped to `num_bins - 1`. Another option would be to rely on an _approximate_ implementation, as proposed in [this article](https://people.maths.ox.ac.uk/gilesm/codes/poissinv/paper.pdf).\n", "\n", "* Yet another alternative for the `icdf` is to rely on `scipy`'s implementation and make use of JAX's `host_callback` module. This feature allows you to call Python functions that are not written in JAX, which means that we can simply reuse `scipy`'s implementation of the Poisson inverse cdf (`ppf`). Inverting the expression for $F_Z(z)$ derived above, we can write the _truncated_ icdf as:\n", "\n", "$$\n", "\\begin{align}\n", "F_Z^{-1}(u) = F_Y^{-1}(Mu + F_Y(L-1))\n", "\\end{align}\n", "$$\n", "\n", "And in Python:\n", "\n", "```python\n", "def scipy_truncated_poisson_icdf(args):  # Note: all arguments are passed inside a tuple\n", "    rate, low, u = args\n", "    rate = np.asarray(rate)\n", "    low = np.asarray(low)\n", "    u = np.asarray(u)\n", "    density = sp_poisson(rate)\n", "    low_cdf = density.cdf(low - 1)\n", "    normalizer = 1.0 - low_cdf\n", "    x = normalizer * u + low_cdf\n", "    return density.ppf(x)\n", "```\n", "\n", "Ordinarily, we could not use the above function inside our NumPyro distribution because it is not written in JAX. The `jax.experimental.host_callback.call` function solves precisely that problem. 
The code below shows how to use it. Keep in mind that `host_callback` is an experimental module, so you should expect its API to change (in more recent JAX releases it has been deprecated in favor of `jax.pure_callback`). See the `host_callback` [docs](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) for more details.\n", "\n", "```python\n", "    # ...\n", "    def icdf_scipy(self, u):\n", "        result_shape = jax.ShapeDtypeStruct(\n", "            u.shape,\n", "            jnp.result_type(float)  # int type not currently supported\n", "        )\n", "        result = jax.experimental.host_callback.call(\n", "            scipy_truncated_poisson_icdf,\n", "            (self.rate, self.low, u),\n", "            result_shape=result_shape\n", "        )\n", "        return result.astype(jnp.result_type(int))\n", "    # ...\n", "```\n", "\n", "Putting it all together, the full implementation is:" ] }, { "cell_type": "code", "execution_count": 29, "id": "13a86497", "metadata": {}, "outputs": [], "source": [ "def scipy_truncated_poisson_icdf(args):  # Note: all arguments are passed inside a tuple\n", "    rate, low, u = args\n", "    rate = np.asarray(rate)\n", "    low = np.asarray(low)\n", "    u = np.asarray(u)\n", "    density = sp_poisson(rate)\n", "    low_cdf = density.cdf(low - 1)\n", "    normalizer = 1.0 - low_cdf\n", "    x = normalizer * u + low_cdf\n", "    return density.ppf(x)\n", "\n", "\n", "class LeftTruncatedPoisson(Distribution):\n", "    \"\"\"\n", "    A left-truncated Poisson distribution.\n", "\n", "    :param numpy.ndarray low: lower bound at which truncation happens\n", "    :param numpy.ndarray rate: rate of the Poisson distribution.\n", "    \"\"\"\n", "\n", "    arg_constraints = {\n", "        \"low\": constraints.nonnegative_integer,\n", "        \"rate\": constraints.positive,\n", "    }\n", "\n", "    def __init__(self, rate=1.0, low=0, validate_args=None):\n", "        batch_shape = lax.broadcast_shapes(jnp.shape(low), jnp.shape(rate))\n", "        self.low, self.rate = promote_shapes(low, rate)\n", "        super().__init__(batch_shape, validate_args=validate_args)\n", "\n", "    def log_prob(self, value):\n", "        m = 1 - poisson.cdf(self.low - 1, self.rate)\n", "        log_p = 
poisson.logpmf(value, self.rate)\n", "        return jnp.where(value >= self.low, log_p - jnp.log(m), -jnp.inf)\n", "\n", "    def sample(self, key, sample_shape=()):\n", "        shape = sample_shape + self.batch_shape\n", "        float_type = jnp.result_type(float)\n", "        minval = jnp.finfo(float_type).tiny\n", "        u = random.uniform(key, shape, minval=minval)\n", "        # return self.icdf(u)  # Brute force\n", "        # return self.icdf_faster(u)  # For faster sampling.\n", "        return self.icdf_scipy(u)  # Using `host_callback`\n", "\n", "    def icdf(self, u):\n", "        def cond_fn(val):\n", "            n, cdf = val\n", "            return jnp.any(cdf < u)\n", "\n", "        def body_fn(val):\n", "            n, cdf = val\n", "            n_new = jnp.where(cdf < u, n + 1, n)\n", "            return n_new, self.cdf(n_new)\n", "\n", "        low = self.low * jnp.ones_like(u)\n", "        cdf = self.cdf(low)\n", "        n, _ = lax.while_loop(cond_fn, body_fn, (low, cdf))\n", "        return n.astype(jnp.result_type(int))\n", "\n", "    def icdf_faster(self, u):\n", "        num_bins = 200  # Choose a reasonably large value\n", "        bins = jnp.arange(num_bins)\n", "        cdf = self.cdf(bins)\n", "        indices = jnp.searchsorted(cdf, u)\n", "        return bins[indices]\n", "\n", "    def icdf_scipy(self, u):\n", "        result_shape = jax.ShapeDtypeStruct(u.shape, jnp.result_type(float))\n", "        result = jax.experimental.host_callback.call(\n", "            scipy_truncated_poisson_icdf,\n", "            (self.rate, self.low, u),\n", "            result_shape=result_shape,\n", "        )\n", "        return result.astype(jnp.result_type(int))\n", "\n", "    def cdf(self, value):\n", "        m = 1 - poisson.cdf(self.low - 1, self.rate)\n", "        f = poisson.cdf(value, self.rate) - poisson.cdf(self.low - 1, self.rate)\n", "        return jnp.where(value >= self.low, f / m, 0)\n", "\n", "    @constraints.dependent_property(is_discrete=True)\n", "    def support(self):\n", "        return constraints.integer_greater_than(self.low - 1)" ] }, { "cell_type": "markdown", "id": "ba41bcd9", "metadata": {}, "source": [ "Let's try it out!" 
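, "\n",
"\n",
"As a quick sanity check of the formulas above, independent of the NumPyro class, we can compare the closed-form truncated ICDF $F_Z^{-1}(u) = F_Y^{-1}(Mu + F_Y(L-1))$ with a brute-force search for the smallest $n$ such that $F_Z(n) \\geq u$. We can also confirm that the re-weighted probabilities $p_Y(n)/M$ sum to one. This is only a sketch in plain NumPy/SciPy. The values `rate = 3.5` and `low = 2`, the 60-point grid and the helper names `icdf_closed_form` and `icdf_brute_force` are illustrative assumptions and are not part of the tutorial's code.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.stats import poisson as sp_poisson\n",
"\n",
"# Illustrative values (assumptions), not taken from the model above\n",
"rate, low = 3.5, 2\n",
"Y = sp_poisson(rate)\n",
"M = 1.0 - Y.cdf(low - 1)  # normalization constant, P(Y >= low)\n",
"\n",
"# Truncated pmf and CDF on a grid that holds essentially all of the mass\n",
"grid = np.arange(60)\n",
"pmf_Z = np.where(grid >= low, Y.pmf(grid) / M, 0.0)\n",
"cdf_Z = np.where(grid >= low, (Y.cdf(grid) - Y.cdf(low - 1)) / M, 0.0)\n",
"\n",
"\n",
"def icdf_closed_form(u):\n",
"    # F_Z^{-1}(u) = F_Y^{-1}(M u + F_Y(L - 1)), via scipy's Poisson ppf\n",
"    return Y.ppf(M * u + Y.cdf(low - 1)).astype(int)\n",
"\n",
"\n",
"def icdf_brute_force(u):\n",
"    # Smallest grid point n with F_Z(n) >= u\n",
"    return grid[np.searchsorted(cdf_Z, u)]\n",
"\n",
"\n",
"u = np.random.default_rng(0).uniform(size=1000)\n",
"print(np.isclose(pmf_Z.sum(), 1.0))\n",
"print(np.array_equal(icdf_closed_form(u), icdf_brute_force(u)))\n",
"print(icdf_closed_form(u).min() >= low)\n",
"```\n",
"\n",
"All three checks are expected to print `True`. A mismatch would typically indicate that $M$ has lost precision, which is exactly what happens when `rate` is much smaller than `low`."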
] }, { "cell_type": "code", "execution_count": 30, "id": "3356242e", "metadata": {}, "outputs": [], "source": [ "def discrete_distplot(samples, ax=None, **kwargs):\n", "    \"\"\"\n", "    Utility function for plotting the samples as a barplot.\n", "    \"\"\"\n", "    x, y = np.unique(samples, return_counts=True)\n", "    y = y / sum(y)\n", "    if ax is None:\n", "        ax = plt.gca()\n", "\n", "    ax.bar(x, y, **kwargs)\n", "    return ax" ] }, { "cell_type": "code", "execution_count": 31, "id": "6103409f", "metadata": {}, "outputs": [], "source": [ "def truncated_poisson_model(num_observations, x=None):\n", "    low = numpyro.sample(\"low\", dist.Categorical(0.2 * jnp.ones((5,))))\n", "    rate = numpyro.sample(\"rate\", dist.LogNormal(1, 1))\n", "    with numpyro.plate(\"observations\", num_observations):\n", "        numpyro.sample(\"x\", LeftTruncatedPoisson(rate, low), obs=x)" ] }, { "cell_type": "markdown", "id": "9714e0e8", "metadata": {}, "source": [ "**Prior samples**" ] }, { "cell_type": "code", "execution_count": 32, "id": "4c28722b", "metadata": {}, "outputs": [], "source": [ "# -- prior samples\n", "num_observations = 1000\n", "num_prior_samples = 100\n", "prior = Predictive(truncated_poisson_model, num_samples=num_prior_samples)\n", "prior_samples = prior(PRIOR_RNG, num_observations)" ] }, { "cell_type": "markdown", "id": "14f6d625", "metadata": {}, "source": [ "**Inference**" ] }, { "cell_type": "markdown", "id": "52835c94", "metadata": {}, "source": [ "As in the truncated normal case, here it is better to replace\n", "the prior on the `low` parameter so that it is consistent with the observed data.\n", "We'd like to have a categorical prior on `low` (so that we can use [DiscreteHMCGibbs](https://num.pyro.ai/en/stable/mcmc.html#discretehmcgibbs))\n", "whose highest category is equal to the minimum value of `x` (so that prior and data are consistent).\n", "However, we have to be careful in the way we write such a model, because JAX does not allow for dynamically sized arrays. 
A simple way of coding this model is to simply specify the number of categories as an argument:" ] }, { "cell_type": "code", "execution_count": 33, "id": "f2600b04", "metadata": {}, "outputs": [], "source": [ "def truncated_poisson_model(num_observations, x=None, k=5):\n", " zeros = jnp.zeros((k,))\n", " low = numpyro.sample(\"low\", dist.Categorical(logits=zeros))\n", " rate = numpyro.sample(\"rate\", dist.LogNormal(1, 1))\n", " with numpyro.plate(\"observations\", num_observations):\n", " numpyro.sample(\"x\", LeftTruncatedPoisson(rate, low), obs=x)" ] }, { "cell_type": "code", "execution_count": 34, "id": "2c97134f", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD7CAYAAABkO19ZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQ50lEQVR4nO3df6zddX3H8edrRarTCAh3RqFd66jbSnRsXotLnDMYWZnRuqxoMdvqwlKX2GyLmq1uCWLnH7BsNkvGlnWDibitEDa3m1DXEDBxMYq9IIKFoRdEacek/HTMABbe++N82U6Ot73fcn/yuc9HctLv9/P9fM/9fPJNX+dzvj8+J1WFJKldP7LYDZAkzS+DXpIaZ9BLUuMMeklqnEEvSY0z6CWpcb2CPsnGJHcnmUqyY5rtb0lya5IjSTYPlZ+d5EtJDiS5Pcl757LxkqSZZab76JOsAL4BvB04COwHLqyqO4fqrAFeDnwEmKiq67ry1wJVVd9M8mrgFuCnq+qxue+KJGk6J/SoswGYqqp7AZLsATYB/xf0VXVft+3Z4R2r6htDy/+Z5EFgDHjsaH/stNNOqzVr1vTugCQJbrnlloeqamy6bX2C/nTg/qH1g8A5x9uIJBuAE4F7jlVvzZo1TE5OHu/bS9KyluTbR9u2IBdjk7wKuBr4zap6dprt25JMJpk8fPjwQjRJkpaNPkF/CFg1tH5GV9ZLkpcD1wN/VFVfnq5OVe2uqvGqGh8bm/abhyTpeeoT9PuBdUnWJjkR2AJM9Hnzrv5ngU8/d4FWkrSwZgz6qjoCbAf2AXcB11bVgSQ7k7wLIMkbkxwELgD+OsmBbvf3AG8B3p/ktu519nx0RJI0vRlvr1xo4+Pj5cVYSTo+SW6pqvHptvlkrCQ1zqCXpMYZ9JLUOINekhrX58lYaVlas+P6XvXuu/Qd89wSaXYc0UtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMb1CvokG5PcnWQqyY5ptr8lya1JjiTZPLJta5Jvdq+tc9VwSVI/MwZ9khXA5cD5wHrgwiTrR6p9B3g/8A8j+74C+BhwDrAB+FiSU2bfbElSX31G9BuAqaq6t6qeBvYAm4YrVNV9VXU78OzIvr8E3FBVj1TVo8ANwMY5aLckqac+QX86cP/Q+sGurI/Z7CtJmgNL4mJskm1JJpNMHj58eLGbI0lN6RP0h4BVQ+tndGV99Nq3q
nZX1XhVjY+NjfV8a0lSH32Cfj+wLsnaJCcCW4CJnu+/DzgvySndRdjzujJJ0gI5YaYKVXUkyXYGAb0CuLKqDiTZCUxW1USSNwKfBU4B3pnk41V1VlU9kuSPGXxYAOysqkfmqS96AVuz4/pe9e679B3z3BKpPTMGPUBV7QX2jpRdPLS8n8Fpmen2vRK4chZtlCTNwpK4GCtJmj8GvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWpcr/vopaXGB6yk/hzRS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIa5w+PaNno82Ml/lCJWuSIXpIaZ9BLUuN6BX2SjUnuTjKVZMc021cmuabbfnOSNV35i5JcleSOJHcl+egct1+SNIMZgz7JCuBy4HxgPXBhkvUj1S4CHq2qM4FdwGVd+QXAyqp6HfAG4APPfQhIkhZGnxH9BmCqqu6tqqeBPcCmkTqbgKu65euAtyUJUMBLk5wAvAR4GvjenLRcktRLn6A/Hbh/aP1gVzZtnao6AjwOnMog9P8HeAD4DvCnVfXILNssSToO830xdgPwDPBqYC3w4SSvGa2UZFuSySSThw8fnucmSdLy0ifoDwGrhtbP6MqmrdOdpjkJeBh4H/BvVfWDqnoQ+CIwPvoHqmp3VY1X1fjY2Njx90KSdFR9gn4/sC7J2iQnAluAiZE6E8DWbnkzcFNVFYPTNecCJHkp8CbgP+ai4ZKkfmYM+u6c+3ZgH3AXcG1VHUiyM8m7umpXAKcmmQI+BDx3C+blwMuSHGDwgfF3VXX7XHdCknR0vaZAqKq9wN6RsouHlp9kcCvl6H5PTFcuSVo4PhkrSY0z6CWpcc5eKc2RPrNjgjNkauE5opekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOKdAkF5AnGZBz4cjeklqnEEvSY3z1I3mhacYpKXDEb0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxvUK+iQbk9ydZCrJjmm2r0xyTbf95iRrhra9PsmXkhxIckeSF89h+yVJM5gx6JOsAC4HzgfWAxcmWT9S7SLg0ao6E9gFXNbtewLwGeC3q+os4K3AD+as9ZKkGfUZ0W8Apqrq3qp6GtgDbBqpswm4qlu+DnhbkgDnAbdX1dcAqurhqnpmbpouSeqjT9CfDtw/tH6wK5u2TlUdAR4HTgVeC1SSfUluTfL7s2+yJOl4zPc0xScAbwbeCHwfuDHJLVV143ClJNuAbQCrV6+e5yZJ0vLSZ0R/CFg1tH5GVzZtne68/EnAwwxG/1+oqoeq6vvAXuDnRv9AVe2uqvGqGh8bGzv+XkiSjqpP0O8H1iVZm+REYAswMVJnAtjaLW8GbqqqAvYBr0vyo90HwC8Cd85N0yVJfcx46qaqjiTZziC0VwBXVtWBJDuByaqaAK4Ark4yBTzC4MOAqno0yScZfFgUsLeq+v30kCRpTvQ6R19Vexmcdhkuu3ho+UnggqPs+xkGt1hKkhaBT8ZKUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktS4+Z7rRo1Ys2Pm59zuu/QdC9ASScfLEb0kNc4RvbRI+nxLAr8pafYc0UtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGeXul1DBv4RQ4opek5jmiX4Yc5UnLiyN6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuN6BX2SjUnuTjKVZMc021cmuabbfnOSNSPbVyd5IslH5qjdkqSeZgz6JCuAy4HzgfXAhUnWj1S7CHi0qs4EdgGXjWz/JPC52TdXknS8+ozoNwBTVXVvVT0N7AE2jdTZBFzVLV8HvC1JAJK8G/gWcGBOWixJOi59gv504P6h9YNd2bR1quoI8Dhwa
pKXAX8AfHz2TZUkPR/zfTH2EmBXVT1xrEpJtiWZTDJ5+PDheW6SJC0vfWavPASsGlo/oyubrs7BJCcAJwEPA+cAm5P8CXAy8GySJ6vqL4Z3rqrdwG6A8fHxeh79kCQdRZ+g3w+sS7KWQaBvAd43UmcC2Ap8CdgM3FRVBfzCcxWSXAI8MRrykpYOp7Bu04xBX1VHkmwH9gErgCur6kCSncBkVU0AVwBXJ5kCHmHwYSBJWgJ6/fBIVe0F9o6UXTy0/CRwwQzvccnzaJ8kaZb8hakG+HVb0rE4BYIkNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcT4wtQT5AJSkueSIXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekhrXK+iTbExyd5KpJDum2b4yyTXd9puTrOnK357kliR3dP+eO8ftlyTNYMagT7ICuBw4H1gPXJhk/Ui1i4BHq+pMYBdwWVf+EPDOqnodsBW4eq4aLknqp89vxm4ApqrqXoAke4BNwJ1DdTYBl3TL1wF/kSRV9dWhOgeAlyRZWVVPzbrlkpaEPr9x7O8bL64+p25OB+4fWj/YlU1bp6qOAI8Dp47U+VXgVkNekhZWnxH9rCU5i8HpnPOOsn0bsA1g9erVC9EkSVo2+ozoDwGrhtbP6MqmrZPkBOAk4OFu/Qzgs8BvVNU90/2BqtpdVeNVNT42NnZ8PZAkHVOfoN8PrEuyNsmJwBZgYqTOBIOLrQCbgZuqqpKcDFwP7KiqL85RmyVJx2HGoO/OuW8H9gF3AddW1YEkO5O8q6t2BXBqkingQ8Bzt2BuB84ELk5yW/f6sTnvhSTpqHqdo6+qvcDekbKLh5afBC6YZr9PAJ+YZRslSbPgk7GS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQsyBcJy1mfCJ3DSJ0nzx6CXtKAc/Cw8T91IUuMMeklqnEEvSY3zHL2kJc1z+rPniF6SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1ztsrj5O3eklLm/9Hf5gjeklqnCN6STpOL7RvDY7oJalxjuglLXt9RuhLZXT+fDiil6TGGfSS1DiDXpIaZ9BLUuN6XYxNshH4c2AF8LdVdenI9pXAp4E3AA8D762q+7ptHwUuAp4Bfqeq9s1Z66fR+kUVSS88i3075owj+iQrgMuB84H1wIVJ1o9Uuwh4tKrOBHYBl3X7rge2AGcBG4G/7N5PkrRA+py62QBMVdW9VfU0sAfYNFJnE3BVt3wd8LYk6cr3VNVTVfUtYKp7P0nSAukT9KcD9w+tH+zKpq1TVUeAx4FTe+4rSZpHqapjV0g2Axur6re69V8Hzqmq7UN1vt7VOdit3wOcA1wCfLmqPtOVXwF8rqquG/kb24Bt3epPAnfPvmtL2mnAQ4vdiEWwHPttn5ePxe73j1fV2HQb+lyMPQSsGlo/oyubrs7BJCcAJzG4KNtnX6pqN7C7R1uakGSyqsYXux0LbTn22z4vH0u5331O3ewH1iVZm+REBhdXJ0bqTABbu+XNwE01+KowAWxJsjLJWmAd8JW5abokqY8ZR/RVdSTJdmAfg9srr6yqA0l2ApNVNQFcAVydZAp4hMGHAV29a4E7gSPAB6vqmXnqiyRpGjOeo9fcS7KtO121rCzHftvn5WMp99ugl6TGOQWCJDXOoJ9nSa5M8mB3C+pzZa9IckOSb3b/nrKYbZwPR+n3JUkOJbmte/3yYrZxriVZleTzSe5MciDJ73blzR7vY/S52WOd5MVJvpLka12fP96Vr01yc5KpJNd0N68sCQb9/PsUg+kfhu0AbqyqdcCN3XprPsUP9xtgV1Wd3b32LnCb5tsR4MNVtR54E/DBbhqQlo/30foM7R7rp4Bzq+pngLOBjUnexGDql13dVDCPMpgaZkkw6OdZVX2BwZ1Iw
4anjLgKePdCtmkhHKXfTauqB6rq1m75v4G7GDwJ3uzxPkafm1UDT3SrL+peBZzLYAoYWGLH2aBfHK+sqge65f8CXrmYjVlg25Pc3p3aaeYUxqgka4CfBW5mmRzvkT5Dw8c6yYoktwEPAjcA9wCPdVPAwBKb7sWgX2Tdg2XL5danvwJ+gsHX3QeAP1vU1syTJC8D/gn4var63vC2Vo/3NH1u+lhX1TNVdTaDp/03AD+1uC06NoN+cXw3yasAun8fXOT2LIiq+m73H+RZ4G9ocCbTJC9iEHh/X1X/3BU3fbyn6/NyONYAVfUY8Hng54GTuylg4CjTvSwWg35xDE8ZsRX410Vsy4J5Luw6vwJ8/Wh1X4i6qbmvAO6qqk8ObWr2eB+tzy0f6yRjSU7ull8CvJ3BtYnPM5gCBpbYcfaBqXmW5B+BtzKY2e67wMeAfwGuBVYD3wbeU1VNXbg8Sr/fyuCrfAH3AR8YOnf9gpfkzcC/A3cAz3bFf8jgnHWTx/sYfb6QRo91ktczuNi6gsFg+dqq2pnkNQx+r+MVwFeBX6uqpxavpf/PoJekxnnqRpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktS4/wXXazAXfWWPzwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Take any prior sample as the true process.\n", "true_idx = 6\n", "true_low = prior_samples[\"low\"][true_idx]\n", "true_rate = prior_samples[\"rate\"][true_idx]\n", "true_x = prior_samples[\"x\"][true_idx]\n", "discrete_distplot(true_x.copy());" ] }, { "cell_type": "markdown", "id": "4b5a1ca8", "metadata": {}, "source": [ "To do inference, we set `k = x.min() + 1`. Note also the use of [DiscreteHMCGibbs](https://num.pyro.ai/en/stable/mcmc.html#discretehmcgibbs):" ] }, { "cell_type": "code", "execution_count": 35, "id": "c6fdc77e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:04<00:00, 808.70it/s, 3 steps of size 9.58e-01. acc. prob=0.93]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5916.30it/s, 3 steps of size 9.14e-01. acc. prob=0.93]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5082.16it/s, 3 steps of size 9.91e-01. acc. prob=0.92]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 6511.68it/s, 3 steps of size 8.66e-01. acc. 
prob=0.94]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " low 4.13 2.43 4.00 0.00 7.00 7433.79 1.00\n", " rate 18.16 0.14 18.16 17.96 18.40 3074.46 1.00\n", "\n" ] } ], "source": [ "mcmc = MCMC(DiscreteHMCGibbs(NUTS(truncated_poisson_model)), **MCMC_KWARGS)\n", "mcmc.run(MCMC_RNG, num_observations, true_x, k=true_x.min() + 1)\n", "mcmc.print_summary()" ] }, { "cell_type": "code", "execution_count": 36, "id": "f59c9431", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray(18.2091848, dtype=float64)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "true_rate" ] }, { "cell_type": "markdown", "id": "f4be5ea6", "metadata": {}, "source": [ "As before, one needs to be extra careful when estimating the truncation point.\n", "If the truncation point is known, it is best to provide it." ] }, { "cell_type": "code", "execution_count": 37, "id": "b93149d1", "metadata": {}, "outputs": [], "source": [ "model_with_known_low = numpyro.handlers.condition(\n", "    truncated_poisson_model, {\"low\": true_low}\n", ")" ] }, { "cell_type": "markdown", "id": "4b0340ac", "metadata": {}, "source": [ "And note that we can use [NUTS](https://num.pyro.ai/en/stable/mcmc.html#nuts) directly, because there's no need to infer any discrete parameters." ] }, { "cell_type": "code", "execution_count": 38, "id": "bcbcc6d8", "metadata": {}, "outputs": [], "source": [ "mcmc = MCMC(\n", "    NUTS(model_with_known_low),\n", "    **MCMC_KWARGS,\n", ")" ] }, { "cell_type": "code", "execution_count": 39, "id": "b3c6f6ab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:03<00:00, 1185.13it/s, 1 steps of size 9.18e-01. acc. 
prob=0.93]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5786.32it/s, 3 steps of size 1.00e+00. acc. prob=0.92]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5919.13it/s, 1 steps of size 8.62e-01. acc. prob=0.94]\n", "sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7562.36it/s, 3 steps of size 9.01e-01. acc. prob=0.93]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " rate 18.17 0.13 18.17 17.95 18.39 3406.81 1.00\n", "\n", "Number of divergences: 0\n" ] } ], "source": [ "mcmc.run(MCMC_RNG, num_observations, true_x)\n", "mcmc.print_summary()" ] }, { "cell_type": "markdown", "id": "55858b35", "metadata": {}, "source": [ "**Removing the truncation**" ] }, { "cell_type": "code", "execution_count": 40, "id": "6614bf7e", "metadata": {}, "outputs": [], "source": [ "model_without_truncation = numpyro.handlers.condition(\n", " truncated_poisson_model,\n", " {\"low\": 0},\n", ")\n", "pred = Predictive(model_without_truncation, posterior_samples=mcmc.get_samples())\n", "pred_samples = pred(PRED_RNG, num_observations)\n", "thinned_samples = pred_samples[\"x\"][::500]" ] }, { "cell_type": "code", "execution_count": 41, "id": "deeb8164", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPj0lEQVR4nO3df6zddX3H8edrrVSGCyh0RgF3u4Fbihqnd3VL3GYkmjIzq1nR4n6whAVIbOZClln3B0OiCRgnWSJx6VZcxW2FoG43o67TYaIzDrkgihXJrlBHOyblh7guQSy+98f5Njs5uT9O6bn3nvPp85E09/v9fD+n930+oa/z4fP9cVJVSJLa9ROrXYAkaXkZ9JLUOINekhpn0EtS4wx6SWrc2tUuYNBZZ51VU1NTq12GJE2Uu++++7GqWj/fsbEL+qmpKWZnZ1e7DEmaKEm+u9Axl24kqXEGvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxY3dnrHSipnbcvujxA9e9ZYUqkcaDQa+T1mIfCH4YqCUu3UhS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGeXmlJorXyEvHzxm9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapyXV0qL8AmXaoEzeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktQ4g16SGud19NIJ8tHJGnfO6CWpcUMFfZLNSR5IMpdkxzzH1yW5pTt+Z5Kprv15SXYnuS/J/UneN+L6JUlLWDLok6wBbgQuAjYClyTZONDtMuDJqjoPuAG4vmu/GFhXVa8EXgtccexDQJK0MoaZ0W8C5qrqwap6BtgDbBnoswXY3W3fBlyYJEABpyVZC5wKPAP8YCSVS5KGMkzQnw083Ld/sGubt09VHQWeAs6kF/r/CzwC/Cfw4ap6YvAXJLk8yWyS2cOHDx/3m5AkLWy5r7rZBDwLvBR4IfClJJ+vqgf7O1XVTmAnwPT0dC1zTRpTPilSWh7DzOgPAef27Z/Ttc3bp1umOR14HHgX8M9V9aOqehT4MjB9okVLkoY3TNDfBZyfZEOSU4BtwMxAnxng0m57K3BHVRW95Zo3AiQ5Dfhl4NujKFySNJwlg75bc98O7APuB26tqv1Jrk3y1q7bLuDMJHPAVcCxSzBvBF6QZD+9D4yPV9U3Rv0mJEkLG2qNvqr2AnsH2q7u236a3qWUg687Ml+7JGnleGesJDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGDfVVgpJO3NSO2xc8duC6t6xgJTrZOKOXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY3zMcVaET6iV1o9zuglqXEGvSQ1bqigT7I5yQNJ5pLsmOf4uiS3dMfvTDLVd+xVSb6SZH+S+5I8f4T1S5KWsGTQJ1kD3AhcBGwELkmycaDbZcCTVXUecANwfffatcAngSur6gLgDcCPRla9JGlJw8zoNwFzVfVgVT0D7AG2DPTZAuzutm8DLkwS4M3AN6rq6wBV9XhVPTua0iVJwxgm6M8GHu7bP9i1zdunqo4CTwFnAi8HKsm+JPck+ZP5fkGSy5PMJpk9fPjw8b4HSdIilvtk7Frg9cBvdz/fnuTCwU5VtbOqpqtqev369ctckiSdXIa5jv4QcG7f/jld23x9Dnbr8qcDj9Ob/X+xqh4DSLIXeA3wrydYt9Qk7zfQchhmRn8XcH6SDUlOAbYBMwN9ZoBLu+2twB1VVcA+4JVJfrL7APh14FujKV2SNIwlZ/RVdTTJdnqhvQa4qar2J7kWmK2qGWAXcHOSOeAJeh8GVNWTST5C78OigL1VtfCURZI0ckM9AqGq9gJ7B9qu7tt+Grh4gdd+kt4llpK
kVeCdsZLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxg31xSPSQhb7jlPwe06lceCMXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktQ4g16SGuedsdKE8W5kHS9n9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGDRX0STYneSDJXJId8xxfl+SW7vidSaYGjr8syZEkfzyiuiVJQ1oy6JOsAW4ELgI2Apck2TjQ7TLgyao6D7gBuH7g+EeAz554uZKk4zXMjH4TMFdVD1bVM8AeYMtAny3A7m77NuDCJAFI8jbgIWD/SCqWJB2XYYL+bODhvv2DXdu8farqKPAUcGaSFwDvBd5/4qVKkp6L5T4Zew1wQ1UdWaxTksuTzCaZPXz48DKXJEknl2GeXnkIOLdv/5yubb4+B5OsBU4HHgdeB2xN8iHgDODHSZ6uqo/2v7iqdgI7Aaanp+s5vA9J0gKGCfq7gPOTbKAX6NuAdw30mQEuBb4CbAXuqKoCfvVYhyTXAEcGQ16StLyWDPqqOppkO7APWAPcVFX7k1wLzFbVDLALuDnJHPAEvQ8DSdIYGOqLR6pqL7B3oO3qvu2ngYuX+DuueQ71SZJOkHfGSlLjDHpJapxBL0mN88vBtaDFvoTaL6CWJoczeklqnEEvSY1z6UZq0GLLbuDS28nGGb0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWrc2tUuQNLqmdpx+4LHDlz3lhWsRMvJGb0kNc6gl6TGGfSS1DiDXpIa58nYk5Qn4aSThzN6SWrcUEGfZHOSB5LMJdkxz/F1SW7pjt+ZZKprf1OSu5Pc1/1844jrlyQtYcmgT7IGuBG4CNgIXJJk40C3y4Anq+o84Abg+q79MeA3q+qVwKXAzaMqXJI0nGFm9JuAuap6sKqeAfYAWwb6bAF2d9u3ARcmSVV9rar+q2vfD5yaZN0oCpckDWeYoD8beLhv/2DXNm+fqjoKPAWcOdDnt4B7quqHg78gyeVJZpPMHj58eNjaJUlDWJGTsUkuoLecc8V8x6tqZ1VNV9X0+vXrV6IkSTppDBP0h4Bz+/bP6drm7ZNkLXA68Hi3fw7wGeD3quo7J1qwJOn4DBP0dwHnJ9mQ5BRgGzAz0GeG3slWgK3AHVVVSc4Abgd2VNWXR1SzJOk4LBn03Zr7dmAfcD9wa1XtT3Jtkrd23XYBZyaZA64Cjl2CuR04D7g6yb3dn58e+buQJC1oqDtjq2ovsHeg7eq+7aeBi+d53QeAD5xgjZJWkXdRTz7vjJWkxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekho31LNuNDkWey4J+GwS6WRk0Es6YU4wxptLN5LUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DjvjJW0Yha7g9a7Z5ePM3pJapwz+gni80QkPRfO6CWpcQa9JDXOoJekxrlGL2mseGXO6Dmjl6TGGfSS1DiDXpIa5xr9mHBdUhqe95QcH2f0ktQ4g16SGufSzQpwWUbSahoq6JNsBv4CWAP8dVVdN3B8HfAJ4LXA48A7q+pAd+x9wGXAs8AfVtW+kVUvSQtwHf//LRn0SdYANwJvAg4CdyWZqapv9XW7DHiyqs5Lsg24Hnhnko3ANuAC4KXA55O8vKqeHfUbkaTn4mT4P+5hZvSbgLmqehAgyR5gC9Af9FuAa7rt24CPJknXvqeqfgg8lGSu+/u+Mpryl88wswFnDNLJYZgPg3HOg1TV4h2SrcDmqvqDbv93gddV1fa+Pt/s+hz
s9r8DvI5e+P97VX2ya98FfLaqbhv4HZcDl3e7Pw88cOJvbeTOAh5b7SKeg0mtGya3duteeZNa+yjr/pmqWj/fgbE4GVtVO4Gdq13HYpLMVtX0atdxvCa1bpjc2q175U1q7StV9zCXVx4Czu3bP6drm7dPkrXA6fROyg7zWknSMhom6O8Czk+yIckp9E6uzgz0mQEu7ba3AndUb01oBtiWZF2SDcD5wFdHU7okaRhLLt1U1dEk24F99C6vvKmq9ie5FpitqhlgF3Bzd7L1CXofBnT9bqV34vYo8O4JvuJmrJeWFjGpdcPk1m7dK29Sa1+Rupc8GStJmmw+AkGSGmfQS1LjDPolJDmQ5L4k9yaZXe16FpPkpiSPdvc1HGt7UZLPJfmP7ucLV7PG+SxQ9zVJDnXjfm+S31jNGueT5NwkX0jyrST7k7yna5+EMV+o9rEe9yTPT/LVJF/v6n5/174hyZ1J5pLc0l04MlYWqf1vkjzUN+avHvnvdo1+cUkOANNVNfY3YyT5NeAI8ImqekXX9iHgiaq6LskO4IVV9d7VrHPQAnVfAxypqg+vZm2LSfIS4CVVdU+SnwLuBt4G/D7jP+YL1f4OxnjcuzvuT6uqI0meB/wb8B7gKuDTVbUnyV8CX6+qj61mrYMWqf1K4J8GbyQdJWf0DamqL9K76qnfFmB3t72b3j/msbJA3WOvqh6pqnu67f8B7gfOZjLGfKHax1r1HOl2n9f9KeCN9B6/AuM75gvVvuwM+qUV8C9J7u4e1TBpXlxVj3Tb/w28eDWLOU7bk3yjW9oZu+WPfkmmgF8E7mTCxnygdhjzcU+yJsm9wKPA54DvAN+vqqNdl4OM6YfWYO1VdWzMP9iN+Q3pPQ14pAz6pb2+ql4DXAS8u1tmmEjdTWyTslb3MeDngFcDjwB/vqrVLCLJC4BPAX9UVT/oPzbuYz5P7WM/7lX1bFW9mt6d9puAX1jdioY3WHuSVwDvo/cefgl4ETDyZT6DfglVdaj7+SjwGXr/YU2S73XrscfWZR9d5XqGUlXf6/5R/Bj4K8Z03Lu11k8Bf1tVn+6aJ2LM56t9UsYdoKq+D3wB+BXgjO7xKzABj1rpq31zt4xW3VN+P84yjLlBv4gkp3UnqkhyGvBm4JuLv2rs9D+e4lLgH1exlqEdC8rO2xnDce9Oru0C7q+qj/QdGvsxX6j2cR/3JOuTnNFtn0rvezLupxeaW7tu4zrm89X+7b5JQeidWxj5mHvVzSKS/Cy9WTz0Hhfxd1X1wVUsaVFJ/h54A71Hn34P+DPgH4BbgZcB3wXeUVVjdeJzgbrfQG/5oIADwBV9695jIcnrgS8B9wE/7pr/lN5a97iP+UK1X8IYj3uSV9E72bqG3kT11qq6tvu3uofe0sfXgN/pZshjY5Ha7wDWAwHuBa7sO2k7mt9t0EtS21y6kaTGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcf8HQTe/0c4vUmYAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "discrete_distplot(thinned_samples.copy());" ] }, { "cell_type": "markdown", "id": "fecec291", "metadata": {}, "source": [ "### References and related material \n", "\n", "1. [Wikipedia page on inverse transform sampling](https://en.wikipedia.org/wiki/Inverse_transform_sampling)\n", "2. [David MacKay's book on information theory](http://www.inference.org.uk/itprnn/book.pdf)\n", "3. [Composite models with underlying folded distributions](https://www.sciencedirect.com/science/article/pii/S0377042720306427)\n", "4. [Application of the generalized folded-normal distribution to the process capability measures](https://link.springer.com/article/10.1007/s00170-003-2043-x)\n", "5. [Pyro SVI tutorial part 3](https://pyro.ai/examples/svi_part_iii.html)\n", "6. [Approximation of the inverse Poisson cumulative distribution function](https://people.maths.ox.ac.uk/gilesm/codes/poissinv/paper.pdf)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }