{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Getting started with PyMC4: Bayesian neural networks\n", "\n", "This article demonstrates how to implement a simple Bayesian neural network for regression with an early [PyMC4 development snapshot](https://github.com/pymc-devs/pymc4/tree/1c5e23825271fc2ff0c701b9224573212f56a534) (from Jul 29, 2020). It can be installed with:\n", "\n", "```bash\n", "pip install git+https://github.com/pymc-devs/pymc4@1c5e23825271fc2ff0c701b9224573212f56a534\n", "```\n", "\n", "I'll update this article from time to time to cover new features or to fix breaking API changes. The latest update (Aug. 19, 2020) includes the recently added support for variational inference (VI). The following sections assume basic familiarity with [PyMC3](https://docs.pymc.io/). If this is not the case, I recommend reading [Getting started with PyMC3](https://docs.pymc.io/notebooks/getting_started.html) first." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.0a2\n", "2.4.0-dev20200818\n", "0.12.0-dev20200818\n" ] } ], "source": [ "import logging\n", "import pymc4 as pm\n", "import numpy as np\n", "import arviz as az\n", "\n", "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", "\n", "print(pm.__version__)\n", "print(tf.__version__)\n", "print(tfp.__version__)\n", "\n", "# Mute TensorFlow warnings ...\n", "logging.getLogger('tensorflow').setLevel(logging.ERROR)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction to PyMC4\n", "\n", "PyMC4 uses [TensorFlow Probability](https://www.tensorflow.org/probability) (TFP) as its backend, and PyMC4 random variables are wrappers around TFP distributions. 
Models must be defined as [generator](https://docs.python.org/3/glossary.html#term-generator) functions, with a `yield` expression for each random variable. PyMC4 uses [coroutines](https://www.python.org/dev/peps/pep-0342/) to interact with the generator to get access to these variables. Depending on the context, PyMC4 may sample values from random variables, compute log probabilities of observed values, and so on. Details are covered in the [PyMC4 design guide](https://github.com/pymc-devs/pymc4/blob/master/notebooks/pymc4_design_guide.ipynb). Model generator functions must be decorated with `@pm.model`, as shown in the following trivial example:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "@pm.model\n", "def model(x):\n", "    # prior for the mean of a normal distribution\n", "    loc = yield pm.Normal('loc', loc=0, scale=10)\n", "    \n", "    # likelihood of observed data\n", "    obs = yield pm.Normal('obs', loc=loc, scale=1, observed=x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This models normally distributed data centered at a location `loc` to be inferred. Inference can be started with `pm.sample()`, which uses the [No-U-Turn Sampler](https://jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf) (NUTS). Samplers other than NUTS are currently not implemented in PyMC4." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
<xarray.Dataset>\n", "Dimensions: (chain: 10, draw: 1000)\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3 4 5 6 7 8 9\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n", "Data variables:\n", " model/loc (chain, draw) float32 3.1448023 3.1448023 ... 3.004984 3.4603796\n", "Attributes:\n", " created_at: 2020-08-19T12:02:26.831730\n", " arviz_version: 0.9.0
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
array([ 0, 1, 2, ..., 997, 998, 999])
array([[3.1448023, 3.1448023, 3.1593454, ..., 3.1752453, 3.5500278,\n", " 3.3998082],\n", " [3.096355 , 2.9651537, 3.107913 , ..., 2.7815466, 3.302857 ,\n", " 3.3626337],\n", " [3.0079126, 3.2646103, 3.2475855, ..., 3.3154554, 3.2643337,\n", " 3.132182 ],\n", " ...,\n", " [3.213623 , 3.1980019, 3.1980019, ..., 2.8879852, 3.052567 ,\n", " 3.4903805],\n", " [3.198672 , 3.4714966, 2.9759057, ..., 3.3256674, 3.007779 ,\n", " 3.007779 ],\n", " [2.835832 , 2.8538866, 2.8538866, ..., 3.075596 , 3.004984 ,\n", " 3.4603796]], dtype=float32)
<xarray.Dataset>\n", "Dimensions: (chain: 10, draw: 1000)\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3 4 5 6 7 8 9\n", " * draw (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999\n", "Data variables:\n", " lp (chain, draw) float32 -43.887486 -43.887486 ... -45.533813\n", " tree_size (chain, draw) int32 3 1 1 1 1 1 1 1 1 ... 3 1 3 1 1 3 1 3\n", " diverging (chain, draw) bool False False False ... False False False\n", " energy (chain, draw) float32 -44.09497 -44.962246 ... -47.662434\n", " mean_tree_accept (chain, draw) float32 -0.11299977 ... -1.8010322\n", "Attributes:\n", " created_at: 2020-08-19T12:02:26.832872\n", " arviz_version: 0.9.0
array([[-43.887486, -43.887486, -43.897667, ..., -43.916054, -46.54659 ,\n", " -44.98605 ],\n", " [-43.899372, -44.28523 , -43.890137, ..., -45.69251 , -44.33845 ,\n", " -44.7044 ],\n", " [-44.10273 , -44.16057 , -44.095512, ..., -44.40666 , -44.159447,\n", " -43.883797],\n", " ...,\n", " [-43.9917 , -43.955574, -43.955574, ..., -44.75345 , -43.970715,\n", " -45.84589 ],\n", " [-43.956974, -45.6463 , -44.234173, ..., -44.465443, -44.103218,\n", " -44.103218],\n", " [-45.17109 , -45.017273, -45.017273, ..., -43.926025, -44.113483,\n", " -45.533813]], dtype=float32)
array([[3, 1, 1, ..., 1, 1, 1],\n", " [1, 3, 1, ..., 3, 3, 3],\n", " [3, 3, 1, ..., 1, 1, 1],\n", " ...,\n", " [1, 1, 1, ..., 1, 1, 3],\n", " [3, 1, 3, ..., 3, 3, 3],\n", " [1, 1, 3, ..., 3, 1, 3]], dtype=int32)
array([[False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False],\n", " ...,\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False],\n", " [False, False, False, ..., False, False, False]])
array([[-44.09497 , -44.962246, -43.90053 , ..., -47.178867, -46.633278,\n", " -46.38453 ],\n", " [-44.1807 , -44.30795 , -44.069817, ..., -45.740902, -44.99158 ,\n", " -44.736084],\n", " [-44.36696 , -44.163994, -44.251163, ..., -44.76689 , -44.440952,\n", " -44.003162],\n", " ...,\n", " [-44.87733 , -44.015293, -45.389606, ..., -44.754826, -44.384026,\n", " -45.847645],\n", " [-44.192295, -45.749332, -46.55621 , ..., -47.500885, -44.505745,\n", " -44.819065],\n", " [-46.209038, -45.75422 , -47.55198 , ..., -44.166405, -44.148354,\n", " -47.662434]], dtype=float32)
array([[-0.11299977, -1.4377098 , -0.00593183, ..., -4.5058823 ,\n", " -1.5323563 , 0. ],\n", " [ 0. , -0.17513256, 0. , ..., -0.24506885,\n", " -0.03921076, -0.09310422],\n", " [-0.01966013, -0.0111048 , 0. , ..., 0. ,\n", " 0. , 0. ],\n", " ...,\n", " [ 0. , 0. , -1.9999924 , ..., -0.50117874,\n", " 0. , -0.5668998 ],\n", " [ 0. , -0.9840775 , -0.8787381 , ..., -0.9555875 ,\n", " -0.13321258, -0.63191855],\n", " [ 0. , 0. , -1.033627 , ..., -0.13674697,\n", " -0.10919955, -1.8010322 ]], dtype=float32)
<xarray.Dataset>\n", "Dimensions: (model/obs_dim_0: 30)\n", "Coordinates:\n", " * model/obs_dim_0 (model/obs_dim_0) int64 0 1 2 3 4 5 6 ... 24 25 26 27 28 29\n", "Data variables:\n", " model/obs (model/obs_dim_0) float64 5.034 3.825 3.969 ... 3.918 2.585\n", "Attributes:\n", " created_at: 2020-08-19T12:02:26.834054\n", " arviz_version: 0.9.0
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", " 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
array([5.03434655, 3.82462505, 3.96940926, 2.3396569 , 3.5108133 ,\n", " 3.0514346 , 2.10080958, 1.43897365, 3.36524396, 2.54937186,\n", " 3.00850969, 2.11411964, 3.51740442, 3.3991824 , 3.80398466,\n", " 4.44110725, 1.68463182, 3.92476337, 2.57933469, 4.36729299,\n", " 1.57571855, 4.29665483, 3.78128164, 2.11619165, 2.56151748,\n", " 2.68069589, 2.30618091, 4.04778368, 3.91752526, 2.58532026])