{ "cells": [ { "cell_type": "markdown", "id": "849decd0-5b2c-4ec8-8f93-9b367aecce65", "metadata": {}, "source": [ "# Custom priors\n", "\n", "In this notebook, we demonstrate how to define custom parameter priors." ] }, { "cell_type": "raw", "id": "0b3b0120-dc05-4e5f-a6ed-21ef347ed0ea", "metadata": { "raw_mimetype": "text/restructuredtext", "tags": [] }, "source": [ "In this notebook:\n", "\n", "* :class:`RV`\n", "* :class:`DistributionBase`" ] }, { "cell_type": "code", "execution_count": 1, "id": "91fda0a7-b8a1-4f4a-8783-afc026213edf", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from pyabc import *\n", "\n", "rng = np.random.default_rng()" ] }, { "cell_type": "markdown", "id": "4e389904-a723-4003-b69b-43cc7731d48a", "metadata": {}, "source": [ "We consider a simple two-dimensional test problem with four posterior modes. Because the model squares each parameter, the data cannot distinguish the sign of `p0` or `p1`:" ] }, { "cell_type": "code", "execution_count": 2, "id": "f5992100-8305-4328-822f-96378c2624c5", "metadata": {}, "outputs": [], "source": [ "# noise standard deviation\n", "std = 0.2\n", "\n", "\n", "def model(p):\n", "    \"\"\"Quadratic model with two inputs and two outputs.\"\"\"\n", "    return {\n", "        \"y0\": p[\"p0\"] ** 2 + std * rng.normal(),\n", "        \"y1\": p[\"p1\"] ** 2 + std * rng.normal(),\n", "    }\n", "\n", "\n", "# ABC distance function\n", "distance = PNormDistance(p=2)\n", "\n", "# ground truth parameters\n", "gt_par = {\"p0\": 1, \"p1\": 1.2}\n", "\n", "# observed data\n", "obs = model(gt_par)\n", "\n", "# ABC population size and maximum total number of simulations\n", "pop_size = 1000\n", "max_total_sim = 50 * pop_size\n", "\n", "# parameter boundaries\n", "prior_bounds = {\"p0\": (-2, 2), \"p1\": (-2, 2)}" ] }, { "cell_type": "markdown", "id": "f21be6c4-d0c4-433c-b458-1a66c3612bf9", "metadata": {}, "source": [ "In most pyABC examples and applications, the parameter priors are independent. Such a prior can be expressed simply via a `Distribution`, which assumes that the passed one-dimensional priors are independent:" ] }, { "cell_type":
"code", "execution_count": 3, "id": "1af01956-f2d4-4135-b93c-d3baca52667f", "metadata": {}, "outputs": [], "source": [ "# RV(\"uniform\", loc, scale) follows the scipy.stats convention,\n", "# so each parameter is uniform on [loc, loc + scale] = [-2, 2]\n", "prior = Distribution(\n", "    p0=RV(\"uniform\", -2, 4),\n", "    p1=RV(\"uniform\", -2, 4),\n", ")" ] }, { "cell_type": "markdown", "id": "9bbb74fd-8863-44bd-a4c4-d81d632f5180", "metadata": {}, "source": [ "With this prior, the posterior exhibits four distinct modes, one for each sign combination of `p0` and `p1`:" ] }, { "cell_type": "code", "execution_count": 4, "id": "211dbea5-8fb4-4961-9232-172964af0072", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ABC.Sampler INFO: Parallelize sampling on 4 processes.\n", "ABC.History INFO: Start \n", "ABC INFO: Calibration sample t = -1.\n", "ABC INFO: t: 0, eps: 1.49234810e+00.\n", "ABC INFO: Accepted: 1000 / 2007 = 4.9826e-01, ESS: 1.0000e+03.\n", "ABC INFO: t: 1, eps: 1.10051564e+00.\n", "ABC INFO: Accepted: 1000 / 2618 = 3.8197e-01, ESS: 9.6474e+02.\n", "ABC INFO: t: 2, eps: 8.41417921e-01.\n", "ABC INFO: Accepted: 1000 / 3349 = 2.9860e-01, ESS: 9.5689e+02.\n", "ABC INFO: t: 3, eps: 6.29025180e-01.\n", "ABC INFO: Accepted: 1000 / 4853 = 2.0606e-01, ESS: 9.4498e+02.\n", "ABC INFO: t: 4, eps: 4.65501309e-01.\n", "ABC INFO: Accepted: 1000 / 7645 = 1.3080e-01, ESS: 9.7925e+02.\n", "ABC INFO: t: 5, eps: 3.31555988e-01.\n", "ABC INFO: Accepted: 1000 / 12925 = 7.7369e-02, ESS: 9.8705e+02.\n", "ABC INFO: t: 6, eps: 2.39407788e-01.\n", "ABC INFO: Accepted: 1000 / 24502 = 4.0813e-02, ESS: 9.8812e+02.\n", "ABC INFO: Stop: Total simulations budget.\n", "ABC.History INFO: Done \n" ] } ], "source": [ "# standard pyABC workflow\n", "abc = ABCSMC(model, prior, distance, population_size=pop_size)\n", "abc.new(create_sqlite_db_id(), obs)\n", "h = abc.run(max_total_nr_simulations=max_total_sim)" ] }, { "cell_type": "code", "execution_count": 5, "id": "d8e0debc-3c11-493d-a88f-cf07785587f4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ABC.Transition INFO: Best params: {'scaling': 0.05}\n", "ABC.Transition INFO: Best params: {'scaling': 0.05}\n", "ABC.Transition INFO: Best params: {'scaling': 0.05}\n" ] }, { "data": { "text/plain": [ "[Figure: matrix of 1D and 2D marginal kernel density estimates for p0 and p1]" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# matrix plot of 1D and 2D marginals via kernel density estimates\n", "visualization.plot_kde_matrix_highlevel(\n", "    h,\n", "    limits=prior_bounds,\n", "    refval=gt_par,\n", "    kde=GridSearchCV(),\n", ")" ] }, { "cell_type": "markdown", "id": "6798c74c-8511-4496-ac9c-b48ec60d690a", "metadata": {}, "source": [ "Now assume we know that the sum of the parameters, `p0 + p1`, must be greater than 1.0. We can encode this constraint in a custom prior derived from the `DistributionBase` class. We need to implement an `rvs()` method that generates (pseudo-)random samples from the distribution (here via rejection sampling: draw from the independent uniform priors until the constraint holds), and a `pdf()` method that evaluates the density at a given parameter. The density need not be normalized, so `pdf()` can simply return the product of the uniform densities inside the feasible region and zero outside it." ] }, { "cell_type": "code", "execution_count": 6, "id": "7c5152cc-a39a-4db2-b80b-7aa3d99ea96d", "metadata": {}, "outputs": [], "source": [ "class ConstrainedPrior(DistributionBase):\n", "    \"\"\"Uniform prior on [-2, 2]^2, restricted to p0 + p1 > min_sum.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.p0 = RV(\"uniform\", -2, 4)\n", "        self.p1 = RV(\"uniform\", -2, 4)\n", "        self.min_sum: float = 1.0\n", "\n", "    def rvs(self, *args, **kwargs):\n", "        # rejection sampling: redraw until the constraint is satisfied\n", "        while True:\n", "            p0, p1 = self.p0.rvs(), self.p1.rvs()\n", "            if p0 + p1 > self.min_sum:\n", "                return Parameter(p0=p0, p1=p1)\n", "\n", "    def pdf(self, x):\n", "        p0, p1 = x[\"p0\"], x[\"p1\"]\n", "        # zero density outside the feasible region\n", "        if p0 + p1 <= self.min_sum:\n", "            return 0.0\n", "        # unnormalized density inside the feasible region\n", "        return self.p0.pdf(p0) * self.p1.pdf(p1)\n", "\n", "\n", "constrained_prior = ConstrainedPrior()" ] }, { "cell_type": "markdown", "id": "4718a699-57c8-4ca7-aa94-ee3fd6b00976", "metadata": {}, "source": [ "As expected, only the upper-right posterior mode, the only one whose parameter sum exceeds 1, is now sampled."
] }, { "cell_type": "code", "execution_count": 7, "id": "6014cd2f-98f3-4474-97ac-c83c9e8bd299", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ABC.Sampler INFO: Parallelize sampling on 4 processes.\n", "ABC.History INFO: Start \n", "ABC INFO: Calibration sample t = -1.\n", "ABC INFO: t: 0, eps: 1.57795072e+00.\n", "ABC INFO: Accepted: 1000 / 1928 = 5.1867e-01, ESS: 1.0000e+03.\n", "ABC INFO: t: 1, eps: 1.14261527e+00.\n", "ABC INFO: Accepted: 1000 / 2073 = 4.8239e-01, ESS: 9.6627e+02.\n", "ABC INFO: t: 2, eps: 8.62637704e-01.\n", "ABC INFO: Accepted: 1000 / 2437 = 4.1034e-01, ESS: 9.4258e+02.\n", "ABC INFO: t: 3, eps: 6.21836300e-01.\n", "ABC INFO: Accepted: 1000 / 2536 = 3.9432e-01, ESS: 8.8251e+02.\n", "ABC INFO: t: 4, eps: 4.60224825e-01.\n", "ABC INFO: Accepted: 1000 / 2651 = 3.7722e-01, ESS: 5.7476e+02.\n", "ABC INFO: t: 5, eps: 3.33140306e-01.\n", "ABC INFO: Accepted: 1000 / 3535 = 2.8289e-01, ESS: 7.6206e+02.\n", "ABC INFO: t: 6, eps: 2.35619085e-01.\n", "ABC INFO: Accepted: 1000 / 4980 = 2.0080e-01, ESS: 6.2725e+02.\n", "ABC INFO: t: 7, eps: 1.70215801e-01.\n", "ABC INFO: Accepted: 1000 / 7465 = 1.3396e-01, ESS: 5.8647e+02.\n", "ABC INFO: t: 8, eps: 1.21842793e-01.\n", "ABC INFO: Accepted: 1000 / 13116 = 7.6243e-02, ESS: 5.4529e+02.\n", "ABC INFO: t: 9, eps: 8.46962594e-02.\n", "ABC INFO: Accepted: 1000 / 25699 = 3.8912e-02, ESS: 3.8159e+02.\n", "ABC INFO: Stop: Total simulations budget.\n", "ABC.History INFO: Done \n" ] } ], "source": [ "# standard pyABC workflow\n", "abc = ABCSMC(model, constrained_prior, distance, population_size=pop_size)\n", "abc.new(create_sqlite_db_id(), obs)\n", "h = abc.run(max_total_nr_simulations=max_total_sim)" ] }, { "cell_type": "code", "execution_count": 8, "id": "dd0cf6b1-8a87-42a5-b39a-7faccb62bbe3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ABC.Transition INFO: Best params: {'scaling': 1.0}\n", "ABC.Transition INFO: Best params: {'scaling': 0.2875}\n", "ABC.Transition INFO: Best params: {'scaling': 1.0}\n" ] }, { "data": { "text/plain": [ "[]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "[Figure: matrix of 1D and 2D marginal kernel density estimates for p0 and p1]" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# matrix plot of 1D and 2D marginals via kernel density estimates\n", "arr_ax = visualization.plot_kde_matrix_highlevel(\n", "    h,\n", "    limits=prior_bounds,\n", "    refval=gt_par,\n", "    kde=GridSearchCV(),\n", ")\n", "# dotted line: constraint boundary p0 + p1 = 1\n", "arr_ax[0][1].plot([-2, 2], [3, -1], linestyle=\"dotted\", color=\"grey\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }