{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "6umP1IKf4Dg6" }, "source": [ "# Autobatching log-densities example\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", "\n", "Inspired by a notebook by @davmre." ] }, { "metadata": { "colab_type": "code", "id": "PaW85yP_BrCF", "colab": {} }, "cell_type": "code", "source": [ "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.11-cp36-none-linux_x86_64.whl\n", "!pip install --upgrade -q jax" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "8RZDkfbV3zdR" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import functools\n", "import itertools\n", "import re\n", "import sys\n", "import time\n", "\n", "from matplotlib.pyplot import *\n", "\n", "import jax\n", "\n", "from jax import lax\n", "from jax import numpy as np\n", "from jax import scipy\n", "from jax import random\n", "\n", "import numpy as onp\n", "import scipy as oscipy" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "p2VcZS1d34C6" }, "source": [ "# Generate a fake binary classification dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "pq41hMvn4c_i" }, "outputs": [], "source": [ "onp.random.seed(10009)\n", "\n", "num_features = 10\n", "num_points = 100\n", "\n", "true_beta = onp.random.randn(num_features).astype(np.float32)\n", "all_x = onp.random.randn(num_points, num_features).astype(np.float32)\n", "y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)" ] }, { "cell_type": "code", 
"execution_count": 3, "metadata": { "colab": { "height": 102 }, "colab_type": "code", "executionInfo": { "elapsed": 30, "status": "ok", "timestamp": 1549999404494, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "O0nVumAw7IlT", "outputId": "c474098f-4e81-4fc8-ad8f-3ba825409be3" }, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,\n", " 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n", " 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,\n", " 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,\n", " 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DZRVvhpn5aB1" }, "source": [ "# Write the log-joint function for the model\n", "\n", "We'll write a non-batched version, a manually batched version, and an autobatched version." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "C_mDXInL7nsP" }, "source": [ "## Non-batched" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "ZHyL2sJh5ajG" }, "outputs": [], "source": [ "def log_joint(beta):\n", " result = 0.\n", " # Note that no `axis` parameter is provided to `np.sum`.\n", " result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))\n", " result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n", " return result" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "height": 34 }, "colab_type": "code", "executionInfo": { "elapsed": 3383, "status": "ok", "timestamp": 1549999409301, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "e51qW0ro6J7C", "outputId": "c778d4fc-85b9-4fea-9875-0d0a3397a027" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/lib/xla_bridge.py:146: UserWarning: No GPU found, falling back to CPU.\n", " warnings.warn('No GPU found, falling back to CPU.')\n" ] }, { "data": { "text/plain": [ "array(-213.23558, dtype=float32)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_joint(onp.random.randn(num_features))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "height": 895 }, "colab_type": "code", "executionInfo": { "elapsed": 4130, "status": "error", "timestamp": 1549999413496, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "fglQXK1Y6wnm", "outputId": "cf85d9a7-b403-4e75-efb6-b9d057e66f3c" }, "outputs": [ { 
"ename": "ValueError", "evalue": "Incompatible shapes for broadcasting: ((100, 10), (1, 100))", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mbatched_test_beta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mlog_joint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mlog_joint\u001b[0;34m(beta)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Note that no `axis` parameter is provided to `np.sum`.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscipy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogpdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mscale\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbeta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc\u001b[0m in \u001b[0;36m\u001b[0;34m(x, y)\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlax_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0m_promote_args_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumpy_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 242\u001b[0;31m \u001b[0mfn\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlax_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0m_promote_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumpy_fn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 243\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_wraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumpy_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc\u001b[0m in \u001b[0;36m_promote_args\u001b[0;34m(fun_name, *args)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;34m\"\"\"Convenience function to apply Numpy argument shape and dtype promotion.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0m_check_arraylike\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_promote_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0m_promote_dtypes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc\u001b[0m in \u001b[0;36m_promote_shapes\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 137\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0mshapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0mnd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbroadcast_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mshapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)\n\u001b[1;32m 141\u001b[0m if len(shp) != nd else arg for arg, shp in zip(args, shapes)]\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/util.pyc\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpopitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlast\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mmemoized_fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/lax.pyc\u001b[0m in \u001b[0;36mbroadcast_shapes\u001b[0;34m(*shapes)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshapes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mresult_shape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mshapes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m raise ValueError(\"Incompatible shapes for broadcasting: {}\"\n\u001b[0;32m---> 69\u001b[0;31m .format(tuple(map(tuple, shapes))))\n\u001b[0m\u001b[1;32m 70\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: Incompatible shapes for broadcasting: ((100, 10), (1, 100))" ] } ], "source": [ "# This doesn't work, because we didn't write `log_joint()` to handle batching.\n", "batch_size = 10\n", "batched_test_beta = onp.random.randn(batch_size, num_features)\n", "\n", "log_joint(onp.random.randn(batch_size, num_features))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_lQ8MnKq7sLU" }, "source": [ "## Manually batched" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "2g5-4bQE7gRA" }, "outputs": [], "source": [ "def batched_log_joint(beta):\n", " result = 0.\n", " # Here (and below) `sum` needs an `axis` parameter. 
At best, forgetting to set axis\n", " # or setting it incorrectly yields an error; at worst, it silently changes the\n", " # semantics of the model.\n", " result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),\n", " axis=-1)\n", " # Note the multiple transposes. Getting this right is not rocket science,\n", " # but it's also not totally mindless. (I didn't get it right on the first\n", " # try.)\n", " result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta.T).T)),\n", " axis=-1)\n", " return result" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "height": 68 }, "colab_type": "code", "executionInfo": { "elapsed": 735, "status": "ok", "timestamp": 1549999417264, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "KdDMr-Gy85CO", "outputId": "1e90fc29-60fb-4460-f08f-2dd486cc8f5e" }, "outputs": [ { "data": { "text/plain": [ "array([-147.84033 , -207.02205 , -109.26075 , -243.8083 , -163.02911 ,\n", " -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ],\n", " dtype=float32)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 10\n", "batched_test_beta = onp.random.randn(batch_size, num_features)\n", "\n", "batched_log_joint(batched_test_beta)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-uuGlHQ_85kd" }, "source": [ "## Autobatched with vmap\n", "\n", "It just works." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "height": 68 }, "colab_type": "code", "executionInfo": { "elapsed": 174, "status": "ok", "timestamp": 1549999417694, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "SU20bouH8-Za", "outputId": "5637b58a-0d7e-4a61-b74a-f4d2cab2105a" }, "outputs": [ { "data": { "text/plain": [ "array([-147.84033 , -207.02205 , -109.26075 , -243.8083 , -163.02911 ,\n", " -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ],\n", " dtype=float32)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vmap_batched_log_joint = jax.vmap(log_joint)\n", "vmap_batched_log_joint(batched_test_beta)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "L1KNBo9y_yZJ" }, "source": [ "# Self-contained variational inference example\n", "\n", "A little code is copied from above." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "lQTPaaQMJh8Y" }, "source": [ "## Set up the (batched) log-joint function" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "AITXbaofA3Pm" }, "outputs": [], "source": [ "@jax.jit\n", "def log_joint(beta):\n", " result = 0.\n", " # Note that no `axis` parameter is provided to `np.sum`.\n", " result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))\n", " result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n", " return result\n", "\n", "batched_log_joint = jax.jit(jax.vmap(log_joint))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UmmFMQ8LJk6a" }, "source": [ "## Define the ELBO and its gradient" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": {}, "colab_type": "code", "id": "MJtnskL6BKwV" }, "outputs": [], "source": [ "def elbo(beta_loc, beta_log_scale, epsilon):\n", " # Monte Carlo ELBO estimate via the reparameterization trick: each row of\n", " # `epsilon` is a standard-normal draw mapped to a sample from q(beta).\n", " beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon\n", " # NOTE(review): the exact entropy of a diagonal Gaussian is\n", " # sum(beta_log_scale + 0.5 * log(2*pi*e)). The constant used below differs, so\n", " # the reported ELBO is shifted by a constant; the gradients are unaffected.\n", " return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))\n", "# `epsilon` is a fresh random array on every call, so it must not be marked\n", "# static: static arguments must be hashable and force a recompile when they change.\n", "elbo = jax.jit(elbo)\n", "elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oQC7xKYnJrp5" }, "source": [ "## Optimize the ELBO using SGD" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "height": 1717 }, "colab_type": "code", "executionInfo": { "elapsed": 2986, "status": "ok", "timestamp": 1549999510348, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "9JrD5nNgH715", "outputId": "1b7949cc-1296-46bb-9d88-412475834944" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"0\t-180.853881836\n", "10\t-113.060455322\n", "20\t-102.737258911\n", "30\t-99.7873535156\n", "40\t-98.9089889526\n", "50\t-98.297454834\n", "60\t-98.1863174438\n", "70\t-97.5797195435\n", "80\t-97.2860031128\n", "90\t-97.4699630737\n", "100\t-97.4771728516\n", "110\t-97.5806732178\n", "120\t-97.494354248\n", "130\t-97.5027313232\n", "140\t-96.8639526367\n", "150\t-97.4419784546\n", "160\t-97.0694046021\n", "170\t-96.8402862549\n", "180\t-97.2133789062\n", "190\t-97.5650253296\n", "200\t-97.2639770508\n", "210\t-97.1197967529\n", "220\t-97.395942688\n", "230\t-97.1683197021\n", "240\t-97.1184082031\n", "250\t-97.2434539795\n", "260\t-97.2978668213\n", "270\t-96.692855835\n", "280\t-96.9643859863\n", "290\t-97.3005523682\n", "300\t-96.6359176636\n", "310\t-97.0351867676\n", "320\t-97.529083252\n", "330\t-97.2881164551\n", "340\t-97.0732192993\n", "350\t-97.1561889648\n", "360\t-97.2588195801\n", "370\t-97.1951446533\n", "380\t-97.1309204102\n", "390\t-97.1172637939\n", "400\t-96.9387359619\n", "410\t-97.2667694092\n", "420\t-97.353225708\n", "430\t-97.2100753784\n", "440\t-97.2843475342\n", "450\t-97.1630859375\n", "460\t-97.2612457275\n", "470\t-97.2134399414\n", "480\t-97.2399749756\n", "490\t-97.1491317749\n", "500\t-97.2352828979\n", "510\t-96.9342041016\n", "520\t-97.212097168\n", "530\t-96.8257751465\n", "540\t-97.0128479004\n", "550\t-96.9417648315\n", "560\t-97.1652069092\n", "570\t-97.2916564941\n", "580\t-97.429397583\n", "590\t-97.2437133789\n", "600\t-97.1521911621\n", "610\t-97.4984436035\n", "620\t-96.9906997681\n", "630\t-96.8895645142\n", "640\t-96.8996887207\n", "650\t-97.1379394531\n", "660\t-97.4370574951\n", "670\t-96.9923629761\n", "680\t-97.1562423706\n", "690\t-97.1869049072\n", "700\t-97.1116027832\n", "710\t-97.7810516357\n", "720\t-97.2322616577\n", "730\t-97.1620635986\n", "740\t-96.9958190918\n", "750\t-96.6672210693\n", "760\t-97.1679534912\n", "770\t-97.5143508911\n", "780\t-97.2890090942\n", "790\t-96.9122619629\n", 
"800\t-97.1709976196\n", "810\t-97.290473938\n", "820\t-97.1624298096\n", "830\t-97.1910629272\n", "840\t-97.5638198853\n", "850\t-97.0019378662\n", "860\t-96.8655548096\n", "870\t-96.7633743286\n", "880\t-96.8366088867\n", "890\t-97.1217956543\n", "900\t-97.0955505371\n", "910\t-97.0682373047\n", "920\t-97.1194763184\n", "930\t-96.8792953491\n", "940\t-97.4562530518\n", "950\t-96.6928024292\n", "960\t-97.293762207\n", "970\t-97.3353042603\n", "980\t-97.349609375\n", "990\t-97.0967559814\n" ] } ], "source": [ "def normal_sample(key, shape):\n", " \"\"\"Convenience function for quasi-stateful RNG.\"\"\"\n", " new_key, sub_key = random.split(key)\n", " return new_key, random.normal(sub_key, shape)\n", "normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n", "\n", "key = random.PRNGKey(10003)\n", "\n", "beta_loc = np.zeros(num_features, np.float32)\n", "beta_log_scale = np.zeros(num_features, np.float32)\n", "\n", "step_size = 0.01\n", "batch_size = 128\n", "epsilon_shape = (batch_size, num_features)\n", "for i in range(1000):\n", " key, epsilon = normal_sample(key, epsilon_shape)\n", " elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(\n", " beta_loc, beta_log_scale, epsilon)\n", " beta_loc += step_size * beta_loc_grad\n", " beta_log_scale += step_size * beta_log_scale_grad\n", " if i % 10 == 0:\n", " print('{}\\t{}'.format(i, elbo_val))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "b3ZAe5fJJ2KM" }, "source": [ "## Display the results\n", "\n", "Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "height": 481 }, "colab_type": "code", "executionInfo": { "elapsed": 263, "status": "ok", "timestamp": 1549999510632, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "zt1NBLoVHtOG", "outputId": "2f0081cf-bbfe-426c-bc5e-a1c09468234a" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbwAAAGtCAYAAABtOsHhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzs3Xd8jWfjBvDrzkCtVlP6Umq9NRNSUk0QYuQkETKEGKF2RbWqSGsVpa2iqNqrVgkVSmNlSWQ0VoxSvCgxqyU0DULGuX9/HMnPSOJEzjnPGdf388nnJGc8z5UTcuV+1i2klCAiIjJ3VkoHICIiMgQWHhERWQQWHhERWQQWHhERWQQWHhERWQQWHhERWQQWHhERWQQWHhERWQQWHhERWQQbpQMUx2uvvSZr1aqldAwiIjIiKSkpt6SUlZ/3PJMqvFq1auHw4cNKxyAiIiMihLikzfO4SZOIiCwCC4+IiCwCC4+IiCyCSe3DK0h2djauXr2KBw8eKB2FyOKUKVMG1atXh62trdJRiJ7L5Avv6tWrqFChAmrVqgUhhNJxiCyGlBJpaWm4evUqateurXQcoucy+U2aDx48gJ2dHcuOyMCEELCzs+PWFTIZJl94AFh2RArh/z0yJWZReERERM/DwtORbdu2QQiBM2fOKJpj0qRJiI6OLvFy/vnnHyxatKjYr5syZQq+/fbbAu9/44034OjoCHt7e/zyyy/FXvaxY8ewa9euYr/u+vXr6NatW7Ff9zghBPr06ZP/dU5ODipXrozOnTuXaLlEZDgsPB0JDQ1F69atERoaqrNl5uTkFPs1U6dORceOHUu87hctvKJ88sknOHbsGDZv3oyBAwdCrVYX6/UvUng5OTmoVq0awsLCivWap5UrVw4nT55EZmYmACAqKgpvvPFGsbIQkbIssvBSLt3BwtjzSLl0RyfLu3v3LhITE7Fy5Ups3Lgx//64uDi0adMG3t7eqF+/PoKDg/N/yZcvXx6ffPIJGjdujA4dOuDmzZsAADc3N4wcORJOTk6YN28eUlNT0b59ezRp0gQdOnTA5cuXAQC+vr5Yu3YtAGDp0qUICgoCAPTv3z//l3utWrUwbtw4ODo6wsnJCUeOHIGHhwfq1q2LJUuW5Gfv0KEDmjVrBgcHB2zfvh0AMHbsWPzxxx9wdHRESEgIAGDWrFl455130KRJE0yePDn/+/zqq69Qr149tG7dGv/73/+e+341bNgQNjY2uHXrVqHf3+bNm2Fvb4+mTZuiTZs2yMrKwqRJk7Bp0yY4Ojpi06ZNuHfvHgYOHIgWLVrg7bffzs++evVq+Pj4oH379ujQoQNSU1Nhb28PQHOQ04ABA+Dg4I
C3334bsbGxBb6mIJ06dcLOnTsBaP7A6dWrV/5jhWVJTU2Fq6srmjVrhmbNmuHXX3/N/7fh5uaGbt26oUGDBggKCoKUMv+9b9SoEZo0aYIxY8Y89/0kIi1JKU3mo3nz5vJpp06deua+ohxOvS3rT9wla4/dIetP3CUPp94u1usL8uOPP8qBAwdKKaV0cXGRhw8fllJKGRsbK0uXLi3/+OMPmZOTIzt27Cg3b94spZQSgPzxxx+llFJ+8cUXcvjw4VJKKdu2bSuHDRuWv+zOnTvL1atXSymlXLlypfT19ZVSSnnjxg1Zt25dGR8fL9966y2ZlpYmpZSyX79++euoWbOmXLRokZRSypEjR0oHBwf577//yr///ltWqVJFSilldna2TE9Pl1JKefPmTVm3bl2pVqvlxYsXZePGjfNzREREyCFDhki1Wi1zc3Olt7e33Ldvnzx8+LC0t7eX9+7dk+np6bJu3bpy1qxZz7xHkydPzr9///79smrVqlKtVhf6/dnb28urV69KKaW8c+eOlFLKVatW5b9PUko5btw4uW7duvznvPXWW/Lu3bty1apV8o033sh/Tx7/Xr799ls5YMAAKaWUp0+fljVq1JCZmZnPvOZp5cqVk8ePH5cBAQEyMzNTNm3aVMbGxkpvb+8is9y7d09mZmZKKaU8e/aszPs3HBsbKytWrCivXLkic3NzpbOzs0xISJC3bt2S9erVk2q1+onv3ZgV9/8gka4BOCy16BCLG+Htv5CGrBw11BLIzlFj/4W0Ei8zNDQUPXv2BAD07Nnzic2aLVq0QJ06dWBtbY1evXohMTERAGBlZYUePXoAAPr06ZN/P4D8+wEgOTkZvXv3BgD07ds3/3mvv/46pk6dinbt2mH27Nl49dVXC8zm4+MDAHBwcMC7776LChUqoHLlyihdujT++ecfSCkxfvx4NGnSBB07dsS1a9fw119/PbOcyMhIREZG4u2330azZs1w5swZnDt3DgkJCfD390fZsmVRsWLF/PUVZO7cuXB0dMSYMWOwadMmCCEK/f5atWqF/v37Y/ny5cjNzS1weZGRkfjmm2/g6OgINzc3PHjwIH+E6O7uXuB7kpiYmL8vrkGDBqhZsybOnj1b5GvyNGnSBKmpqQgNDUWnTp20ypKdnY0hQ4bAwcEB3bt3x6lTp/Jf06JFC1SvXh1WVlZwdHREamoqXn75ZZQpUwaDBg3C1q1bUbZs2ULzEFHxmPyJ58XlXMcOpWyskJ2jhq2NFZzr2JVoebdv38bevXtx4sQJCCGQm5sLIQRmzZoF4NnDtgs7jPvx+8uVK6fVuk+cOAE7Oztcv3690OeULl0agKZg8z7P+zonJwfr16/HzZs3kZKSAltbW9SqVavA86qklBg3bhyGDh36xP3fffedVlkBzT48bTfRLVmyBAcOHMDOnTvRvHlzpKSkFJhpy5YtqF+//hP3HzhwQOv38HHavMbHxwdjxoxBXFwc0tL+/4+lwrJMmTIFr7/+Oo4fPw61Wo0yZcrkP/b4z8Pa2ho5OTmwsbHBwYMHERMTg7CwMCxYsAB79+4t9vdCRM+yuBFe85qVsH6wM0ap6mP9YGc0r1mpRMsLCwtD3759cenSJaSmpuLKlSuoXbs2EhISAAAHDx7ExYsXoVarsWnTJrRu3RoAoFar8/e1bdiwIf/+p7Vs2TJ/v+D69evh6uqav9zdu3fj6NGj+Pbbb3Hx4sUXyp+eno4qVarA1tYWsbGxuHRJM8tGhQoVkJGRkf88Dw8P/PDDD7h79y4A4Nq1a/j777/Rpk0bbNu2DZmZmcjIyEB4eHix1l/Y9/fHH3/g3XffxdSpU1G5cmVcuXKlwEzz58/P3/d19OjR567P1dUV69evBwCcPXsWly9ffqakijJw4EBMnjwZDg4OT9xfWJb09HRUrVoVVlZWWLduXaGj1Tx3795Feno6OnXqhLlz5+L48eNaZyMFJScD06drbsloWVzhAZrSG97uvyUuO0
CzOdPf3/+J+wICAvI3a77zzjv48MMP0bBhQ9SuXTv/ueXKlcPBgwdhb2+PvXv3YtKkSQUuf/78+Vi1ahWaNGmCdevWYd68eXj48CGGDBmCH374AdWqVcPs2bMxcODA/F+2xREUFITDhw/DwcEBa9euRYMGDQAAdnZ2aNWqFezt7RESEgKVSoXevXvDxcUFDg4O6NatGzIyMtCsWTP06NEDTZs2hZeXF955551irb+g7w8AQkJC4ODgAHt7e7Rs2RJNmzZFu3btcOrUqfyDVj7//HNkZ2ejSZMmaNy4MT7//PPnru+DDz6AWq2Gg4MDevTogdWrVz8x0nqe6tWrY8SIEc/cX1iWDz74AGvWrEHTpk1x5syZ544iMzIy0LlzZzRp0gStW7fGnDlztM5GCklOBjp0AD7/XHPL0jNa4kV+SSrFyclJPj0B7OnTp9GwYUOFEhUtLi4O3377LXbs2PHMY+XLl88fLRGZMmP+P2gQ06dryi43F7C2BqZNA8aNUzqVRRFCpEgpnZ73PIsc4RER6YybG1CqlKbsSpXSfE1GyeIOWjEkNzc3uBXyj5+jOyIz4eICxMQAcXGasnNxUToRFYKFR0RUUi4uLDoTwE2aRERkEVh4RERkEVh4RERkEVh4RERkEVh4RERkEVh4OmQsk8C2bNlSJ8vR9SSw1tbW+RPAdu/eHffv3zdYJqBk78uVK1fQrl07NGrUCI0bN86/Ikxx5H3/eR/ffPPNC+cpzvqaNm36xNRERJaKhadDup4EVkpZ7ElSAejsF5uuJ4F96aWXcOzYMZw8eRKlSpXKn5NP35ny3sfivC9Pv/c2NjaYPXs2Tp06hf3792PhwoVPzHygjbzvP+9j7NixRa6zOD//gp6bt77jx49j+vTpGFeMq3+86L89ImNmmYWnhwu9FjQJbGpqav7kng0bNkS3bt3yRzWFPZaamor69evjvffeg729Pa5cuYI5c+bA3t4e9vb2+bMTHDp0CE2aNMGDBw9w7949NG7cGCdPngSguWzZ4+vo378/6tWrh6CgIERHR6NVq1Z46623cPDgwfz8fn5+aN68ORo3boxly5YBKHgS2B9//BEtWrSAo6Mjhg4dmn8x5OJOAuvq6orz588DQIHfH6CZVNXb2xtNmzaFvb09Nm3apHWmgt7HvPelsHUW9Jo8VatWRbNmzQBoLqzdsGFDXLt2Lf/x48ePo02bNmjUqBGsrKwghCj0+qiPe3qdCQkJBWYobt6n/fvvv6hU6f+vHVvQz7ug5T39/hOZNG0mzTOWD11MACt//VXKl16S0tpac/vrr8V7fSEKmgT24sWLEoBMTEyUUko5YMCA/ElQC3vs4sWLUgghk5OTpZQyf4LVu3fvyoyMDNmoUSN55MgRKaWUEyZMkKNHj5YffPCB/Prrr/OzlCtXLn8d1tbW8rfffpO5ubmyWbNmcsCAAVKtVstt27blT7Yqpcyf+PT+/fuycePG8tatW89MAnvq1CnZuXNnmZWVJaWUctiwYXLNmjVaTwKblys7O1v6+PjIRYsWFfn9hYWFycGDB+e//p9//tE609Pv4+PrL2ydBb2mIBcvXpQ1atTInzg3MzNT1q9fXx44cEBKKeXEiRPlmDFj8idxzWNlZSWbNm2a/7Fx48Zn1llQhhfNm7e++vXry4oVK+ZPTCxl4T/vx5dX0PtfEE4AS0oDJ4AtRFwckJWludBrVpbmax0obBLYGjVqoFWrVgCenei1sMdq1qwJZ2dnAJoJS/39/VGuXDmUL18eXbt2zZ96aNKkSYiKisLhw4fx6aefFpirdu3acHBwgJWVFRo3bowOHTpACAEHBwekpqbmP+/7779H06ZN4ezsjCtXruDcuXPPLCsmJgYpKSl455134OjoiJiYGFy4cEHrSWAzMzPh6OgIJycnvPnmmxg0aFCR35+DgwOioqLw2WefISEhAS+//LLWmZ5+Hx9X1DoLe02eu3fvIiAgAN999x0qVqwIAIiOjk
azZs3QokULAJqJYm/fvv3M3IdPb9LMm+j36XU+/fWL5s1b35kzZ7Bnzx689957+TNqFPbzfnx52rz/RKbE8i4tlneh16wsnV3otbBJYIcPH17kBLCFPabt5KVpaWm4e/cusrOz8eDBgwJf9/Skr49PCJuTkwNAM6tDdHQ0kpOTUbZs2fwZu58mpUS/fv0wffr0J+7XdhLYvF/A2qpXrx6OHDmCXbt2YeLEiejQoQPee+89rTKlpqbqfBLY7OxsBAQEICgoCF27ds2//+TJk0/Mj3fkyJH8zZ8vss7i5Nb2uS4uLrh16xZu3ryJU6dOFfrzfnx5Bb3/2mymJTJWio3whBBlhBAHhRDHhRC/CyG+MMiK8y70Om2a5lYH178rbBLYK1eu4PLly0h+tK/w6Ylei3osj6urK7Zt24b79+/j3r17+Pnnn/MnSR06dCimTZuGoKAgfPbZZy+cPz09HZUqVULZsmVx5swZ7N+/H8Czk8B26NABYWFh+PvvvwFoiv7SpUslmgS2qO/v+vXrKFu2LPr06YOQkBAcOXJE60wvus7CSCkxaNAgNGzYEKNGjXriMTs7O/z2228ANJPKbt26NX+0rwsvkvdpZ86cQW5uLuzs7Ar9eT+toPefyJQpOcJ7CKC9lPKuEMIWQKIQYreUsuD/fbqk4wu9hoaGPlM4AQEBmD59OurXr4+FCxdi4MCBaNSoEYYNG5b/nIIey/vFnadZs2bo379//uaywYMH4+2338batWtha2uL3r17Izc3Fy1btsTevXvRvn37Yuf39PTEkiVL0LBhQ9SvXz9/k9bjk8B6eXlh1qxZ+PLLL6FSqaBWq2Fra4uFCxfC2dk5fxLYKlWqFGsS2MK+PwA4ceIEQkJCYGVlBVtbWyxevFjrTP/5z3+Kvc7HN/E+LSkpCevWrYODgwMcHR0BAF9//TU6deqEXr164ZdffoG9vT1ee+01hIaGws7O7pll5G3Sffx9Dw4OfuH3qKi8T69PSok1a9bA2tq60J/30wp6/4lMmVFMACuEKAsgEcAwKeWBwp5nahPApqamonPnzvlHT2r7GJEpMeb/g2S8tm/fjhYtWqBq1aolXpZJTAArhLAWQhwD8DeAqILKTgjxvhDisBDi8M2bNw0fkoiIdGrJkiXw9/fH5MmTDbpeRQtPSpkrpXQEUB1ACyGEfQHPWSaldJJSOlWuXNnwIUugVq1ahY7ginqMiMgcSSnx1VdfYdiwYfD29tb6gDddMYrTEqSU/wCIBeCpdBYiItI9tVqNUaNGYeLEiejbty+2bt2KsmXLGjSDkkdpVhZCvPLo85cAuANQ9iKURESkc9nZ2ejXrx++++47jBw5EqtXr4atra3Bcyh5lGZVAGuEENbQFO9PUsodL7IgKeUz57QRkf4Zw0FvZNzu37+PwMBA7Ny5E1999RXGjRun2O9rxQpPSvkbgLdLupwyZcogLS0NdnZ2LD0iA5JSIi0tDWXKlFE6Chmpf/75B126dEFSUhKWLFmCoUOHKprH5K+0Ur16dVy9ehU8gpPI8MqUKYPq1asrHYOM0I0bN+Dh4YHTp09j06ZN6N69u9KRTL/wbG1tUbt2baVjEBHRIxcuXIC7uzv++usv7Ny5E+7u7kpHAmAGhUdERMbjt99+g4eHB7KysrB37978KwQZA6M4LYGIiExfYmIi2rRpA2trayQmJhpV2QEsPCIi0oGdO3dCpVLh9ddfR1JSklFebo6FR0REJbJ+/Xr4+vqiUaNGSExMRM2aNZWOVCAWHhERvbDvv/8effr0QZs2bbB3714Y8yUgWXhERFRsUkpMmjQJH3/8Mfz9/bFr1y5UrFhR6VhF4lGaRERULLm5ufjoo4+wePFiDBo0CEuWLIGNjfHXCUd4RESktaysLAQFBWHx4sX47LPPsHz5cpMoO4AjPCIi0tK9e/fQtWtXREZGYtasWRgzZozSkYqFhUdERM+VlpYGb29vHDp0CD/88AMGDBigdKRiY+
EREVGRrl27BpVKhT/++ANbtmyBn5+f0pFeCAuPiIgKdfbsWahUKty+fRt79uyBm5ub0pFeGAuPiKiEUi7dwf4LaXCuY4fmNSspHUdnjhw5Ak9PTwBAXFwcmjVrpnCikmHhERGVQMqlOwhasR9ZOWqUsrHC+sHOZlF6cXFx8PHxwauvvorIyEjUq1dP6UglxtMSiIhKYP+FNGTlqKGWQHaOGvsvpCkdqcS2bdsGT09P1KhRA0lJSWZRdgALj4ioRJzr2KGUjRWsBWBrYwXnOnZKRyqRVatWISAgAI6OjkhISMAbb7yhdCSd4SZNIqISaF6zEtYPdjaLfXizZs3Cp59+CpVKhS1btqB8+fJKR9IpFh4RUQk1r1nJpItOSomxY8di5syZ6NGjB9auXYtSpUopHUvnWHhERBYsJycHwcHBWLlyJYYNG4b58+fD2tpa6Vh6wX14REQW6sGDBwgMDMTKlSsxadIkLFy40GzLDuAIj4jIIv3777/w8/NDbGws5s2bhxEjRigdSe9YeEREFubmzZvw8vLC8ePH8eOPPyIoKEjpSAbBwiMisiCXL1+Gu7s7rly5gu3bt6NTp05KRzIYFh4RkYU4deoUVCoV7t27h6ioKLRq1UrpSAbFg1aIiCzAgQMH4OrqitzcXOzbt8/iyg5g4RERmb2oqCh06NABr7zyCpKSktCkSROlIymChUdEZMY2b94Mb29v/Pe//0VSUhLq1KmjdCTFsPCIiMzUkiVL0KNHD7z77ruIi4vDf/7zH6UjKYqFR0RkZqSU+OqrrzBs2DB4e3sjIiICr7zyitKxFMfCIyIyI2q1GqNGjcLEiRPRp08fbN26FWXLllU6llHgaQlERGYiOzsbAwcOxI8//oiPP/4Yc+bMgZUVxzV5WHhERGYgMzMTgYGB2LFjB7788kuMHz8eQgilYxkVFh4RkYn7559/4OPjg8TERCxevBjBwcFKRzJKLDwiIhN248YNeHp64tSpU9i4cSMCAwOVjmS0WHhERCbqwoULUKlUuHHjBnbu3Al3d3elIxk1Fh4RkQn67bff4OHhgaysLMTExODdd99VOpLR4+E7REQmJikpCW3btoW1tTUSEhJYdlpi4RERmZBdu3bB3d0dVapUQVJSEho1aqR0JJPBwiMiMhHr16+Hr68vGjVqhMTERNSsWVPpSCaFhUdEpik5GZg+XXNrAb7//nv06dMHrq6u2Lt3LypXrqx0JJPDg1aIyPQkJwMdOgBZWUCpUkBMDODionQqvZBSYvLkyZg2bRr8/f2xYcMGlClTRulYJokjPCIyPXFxmrLLzdXcxsUpnUgvcnNzMXz4cEybNg2DBg3CTz/9xLIrARYeEZkeNzfNyM7aWnPr5qZ0Ip3LyspCUFAQFi9ejE8//RTLly+HjQ03ypWEYu+eEKIGgLUAXgcgASyTUs5TKg8RmRAXF81mzLg4TdmZ2ebMe/fuoWvXroiMjMTMmTMREhKidCSzoOSfCzkARkspjwghKgBIEUJESSlPKZiJiEyFi4vZFR0A3L59G97e3jh48CBWrlyJgQMHKh3JbChWeFLKPwH8+ejzDCHEaQBvAGDhEZFFunbtGjw8PHD+/Hls2bIFfn5+SkcyK0axQVgIUQvA2wAOKJuEiIxWcrLZbsIEgHPnzsHd3R23b9/G7t270a5dO6UjmR3FC08IUR7AFgAjpZT/FvD4+wDeB4A333zTwOmIyCiY+WkIR48ehYeHBwAgLi4OzZo1UziReVL0KE0hhC00ZbdeSrm1oOdIKZdJKZ2klE480ZLIQpnxaQhxcXFo27YtXnrpJSQmJrLs9EixwhOaqXhXAjgtpZyjVA4iMgFmehrC9u3b4enpiRo1aiApKQn16tVTOpJZU3KE1wpAXwDthRDHHn10UjAPERmrvNMQpk0zm82Zq1atQteuXeHo6Ij4+HhUr15d6UhmT8mjNBMBCKXWT0QmxoxOQ/j2228REhIClUqFLVu2oHz58kpHsg
i80goRUUlpeSFrKSXGjh2LkJAQ9OjRA+Hh4Sw7A1L8KE0iIpOm5RGkOTk5CA4OxsqVKzFs2DDMnz8f1tbWCgS2XBzhERGVhBZHkD548ACBgYFYuXIlJk2ahIULF7LsFMARHhFRSeQdQZo3wnvqCNJ///0Xfn5+iI2Nxbx58zBixAhFYhILj4ioZIq4kPXNmzfh5eWFY8eO4ccff0RQUJBiMYmFR0RUcgUcQXr58mW4u7vj8uXL2L59O7y9vRUKR3lYeEREOnb69GmoVCpkZGQgKioKrVu3VjoSgQetEBHp1MGDB+Hq6oqcnBzEx8ez7IwIC4+ISEeio6PRvn17vPzyy0hMTESTJk2UjkSPYeEREelAWFgYOnXqhLp16yIxMRF169ZVOhI9hYVHRFRCS5cuRWBgIN59913s27cPVatWVToSFYCFR0T0gqSU+PrrrxEcHIxOnTohIiICr7zyitKxqBAsPCKiF6BWqzF69GhMmDABffr0wc8//4yyZcsqHYuKwNMSiIiKKTs7G4MGDcK6devw8ccfY86cObCy4vjB2LHwiIiKITMzE4GBgdixYwe+/PJLjB8/Hpr5rMnYsfCIiLT0zz//wMfHB4mJiVi8eDGCg4OVjkTFwMIjItLCjRs34OnpiVOnTmHjxo0IDAxUOhIVEwuPiExSyqU72H8hDc517NC8ZiW9ruvChQtQqVS4ceMGdu7cCXd3d72uj/SDhUdEJifl0h0ErdiPrBw1StlYYf1gZ72V3m+//QYPDw9kZWUhJiYG7777rl7WQ/rHw4qIyOTsv5CGrBw11BLIzlFj/4U0vawnKSkJbdu2hbW1NRISElh2Jo6FR0Qmx7mOHUrZWMFaALY2VnCuY6fzdezatQvu7u6oUqUKkpKS0KhRI52vgwyLmzSJyOQ0r1kJ6wc7620f3oYNG9CvXz80adIEu3fvRpUqVXS6fFIGC4+ITFLzmpX0st9u/vz5GDFiBNzc3LB9+3ZUrFhR5+sgZXCTJhERNNfFnDx5MkaMGAE/Pz/s3r2bZWdmOMIjIounVqvx0UcfYdGiRRg4cCCWLl0KGxv+ejQ3HOERkUXLyspCUFAQFi1ahE8//RQrVqxg2Zkp/lSJyGLdu3cPAQEBiIiIwMyZMxESEqJ0JNIjFh4RWaTbt2/D29sbBw8exMqVKzFw4EClI5GesfCIyOJcu3YNHh4eOH/+PLZs2QI/Pz+lI5EBsPCIyKKcO3cO7u7uuH37Nnbv3o127dopHYkMhIVHRCZBFxeLPnr0KDw8PAAAcXFxaNasmS4jkpFj4RGR0dPFxaLj4uLg4+ODSpUqISoqCvXq1dNTWjJWPC2BiIxeSS8WvX37dnh6eqJ69epISkpi2VkoFh4RGb2SXCx69erV6Nq1KxwdHZGQkIDq1avrMSkZM27SJCKj96IXi549ezbGjBkDd3d3bN26FeXLl9dzUjJmLDwiMgnFuVi0lBLjx4/HN998g8DAQKxduxalS5fWc0Iydiw8IjIrubm5CA4OxooVKxAcHIwFCxbA2tpa6VhkBLgPj4jMxoMHDxAYGIgVK1bg888/x6JFi1h2lI8jPCIyCxkZGfDz88PevXvx3Xff4eOPP1Y6EhkZFh4RmbybN2+iU6dOOHr0KNatW4c+ffooHYmMEAuPiEza5cuXoVKpcOnSJWzfvh3e3t5KRyIjxcIjIpN1+vRpqFQqZGRkICoqCq1bt1Y6EhkxHrRCRCbp4MGDcHV1RU6h1VrPAAAgAElEQVRODuLj41l29FwsPCIyOdHR0Wjfvj1efvllJCYmokmTJkpHIhPAwiMikxIWFoZOnTqhbt26SExMRN26dZWORCaChUdEJmPp0qUIDAxEixYtsG/fPlStWlXpSGRCFC08IcQPQoi/hRAnlcxBRMZNSomvv/4awcHB8PLyQmRkJF555RWlY5GJUXqEtxqAp8IZiMiIqdVqjB49GhMmTECfPn2wbds2lC1bVulYZIIULTwpZTyA20
pmICLjlZ2djQEDBmDu3LkYMWIE1qxZA1tbW82DycnA9OmaWyItGP15eEKI9wG8DwBvvvmmwmmIyFAyMzPRo0cPhIeHY9q0aZgwYQKEEJoHk5OBDh2ArCygVCkgJgZwcVE2MBk9pTdpPpeUcpmU0klK6VS5cmWl4xCRAaSnp8PDwwM7duzAokWLMHHixP8vOwCIi9OUXW6u5jYuTqmoZEKMfoRHRJblxo0b8PT0xKlTp7Bx40YEBgY++yQ3N83ILm+E5+Zm6Jhkglh4RGQ0Ll68CHd3d/z555/YsWMHVCpVwU90cdFsxoyL05QdN2eSFhQtPCFEKAA3AK8JIa4CmCylXKlkJiJSxokTJ+Dh4YGHDx8iJiYGzs7ORb/AxYVFR8WiaOFJKXspuX4iMg6//vorvL29Ua5cOSQkJKBRo0ZKRyIzZPQHrRCRedu1axc6duyIKlWqICkpiWVHesPCIyLFbNiwAb6+vmjYsCESEhJQs2ZNpSORGWPhEZEi5s+fj6CgILRu3RqxsbGoUqWK0pHIzLHwiMigpJSYPHkyRowYAT8/P+zevRsVK1ZUOhZZgOcWnhDCWQhxSAhxVwiRJYTIFUL8a4hwRGRe1Go1PvzwQ0ydOhUDBgzA5s2bUaZMGaVjkYXQZoS3AEAvAOcAvARgMICF+gxFROYnKysLQUFBWLRoEUJCQrBy5UrY2PBUYDIcrTZpSinPA7CWUuZKKVeBMxwQUTHcu3cPPj4+2LhxI2bMmIGZM2c+eakwIgPQ5s+r+0KIUgCOCSFmAvgT3PdHRFq6ffs2OnfujAMHDmDFihUYNGiQ0pHIQmlTXH0fPe9DAPcA1ADQVZ+hiMg8XLt2DW3atEFKSgrCwsJYdqQobQrPT0r5QEr5r5TyCynlKACd9R2MiEzbuXPn0Lp1a1y+fBl79uyBv7+/0pHIwmlTeP0KuK+/jnMQkRk5evQoWrdujbt37yI2Nhbt2rUr+UI54SuVUKH78IQQvQD0BlBbCPHLYw9VBGcpJ6JC7Nu3Dz4+PnjllVcQGRmJ+vXrl3yhnPCVdKCog1Z+heYAldcAzH7s/gwAv+kzFJFFS0422WlvfvnlFwQGBqJOnTqIjIxE9erVdbPggiZ8NbH3hpRXaOFJKS8BuATARQhRE8BbUspoIcRL0JyPl2GgjESWw4RHMqtXr8bgwYPh5OSEnTt3ws7OTncL54SvpAPaXGllCIAwAEsf3VUdwDZ9hiKyWAWNZEzA7NmzMWDAALRv3x7R0dG6LTvg/yd8nTbNpP4IIOOizXl4wwG0AHAAAKSU54QQvMorkT4UdySj8OZPKSXGjx+Pb775BoGBgVi7di1Kly6tn5VxwlcqIW0K76GUMivvqghCCBsAUq+piCxV3khGmxJTePNnbm4ugoODsWLFCgQHB2PBggWwtrY22PqJikubwtsnhBgP4CUhhDuADwCE6zcWkQXTdiSj4IEcDx48QFBQELZu3YqJEydi6tSpvFQYGT1tCm8sgEEATgAYCmAXgBX6DEVEWlDoQI6MjAz4+flh7969mDt3LkaOHGmQ9RKV1HMLT0qpFkKsgWYfngTwPyklN2kSKa04mz915ObNm+jUqROOHj2KtWvXom/fvnpfJ5GuPLfwhBDeAJYA+AOAgOZE9KFSyt36DkdEz2HAAzkuX74MlUqFS5cuYdu2bejcmVcYJNOizSbN2QDaPZoiCEKIugB2AmDhEVmI06dPQ6VSISMjA5GRkXB1dVU6ElGxaVN4GXll98gF8KRzIotx6NAheHl5wcbGBvv27UPTpk2VjkT0Qoq6lmbeFECHhRC7APwEzT687gAOGSAbESksOjoafn5+qFKlCqKiolC3bl2lIxG9sKJGeF0e+/wvAG0ffX4TQBm9JSIioxAWFoagoCDUr18fERERqFq1qtKRiEqkqGtpDjBkECIyHsuWLUNwcDBatmyJ8PBwVKpUSelIRCWmzXx4RGQhpJSYPn06hg4dCi8vL0RGRrLsyG
yw8IgIAKBWqzF69GiMHz8effr0wbZt21C2bFmlYxHpjDZHaRKRmcvOzsbgwYOxdu1ajBgxAnPnzoWVFf8eJvNS1FGao4p6oZRyju7jEJGhZWZmokePHggPD8e0adMwYcIEXheTzFJRI7wKj27rA3gHwC+Pvu4C4KA+QxGRYaSnp6NLly5ITEzEokWLMGzYMKUjEelNUUdpfgEAQoh4AM2klBmPvp4CzZVWiMiE3bhxA56enjh16hRCQ0PRo0cPpSMR6ZU2+/BeB5D12NdZj+4jIhN18eJFuLu7488//0R4eDg8PDyUjkSkd9oU3loAB4UQPz/62g/AGv1FIiJ9OnHiBDw8PPDgwQPExMTA2dlZ6UhEBqHN9EBfCSF2A8i7WuwAKeVR/cYiIn349ddf4e3tjbJlyyIhIQGNGzdWOhKRwWh73HFZAP9KKecBuCqEqK3HTESkB7t370bHjh1RuXJlJCUlsezI4jy38IQQkwF8BmDco7tsAfyoz1BEpFuhoaHw8fFBgwYNkJiYiFq1aikdicjgtBnh+QPwAXAPAKSU1/H/pywQkZFbsGABgoKC0KpVK8TFxaFKlSpKRyJShDaFlyWllNBMDQQhRDn9RiIiXZBSYsqUKfjoo4/g6+uLPXv2oGLFikrHIlKMNoX3kxBiKYBXhBBDAEQDWKHfWERUEmq1Gh999BG++OILDBgwAJs3b0aZMpzViyybNkdpfiuEcAfwLzRXXZkkpYzSezIieiFZWVno378/QkNDERISghkzZvBSYUTQovCEEDOklJ8BiCrgPiIyIvfu3UO3bt2wZ88ezJgxA59++qnSkYiMhjabNN0LuM9L10GIqGRu374Nd3d3REZGYsWKFSw7oqcUNVvCMAAfAKgjhPjtsYcqAEjSdzAi0t61a9fg4eGBc+fOISwsDP7+/kpHIjI6RW3S3ABgN4DpAMY+dn+GlPK2XlMRkdbOnTsHlUqFW7duYc+ePWjXrp3SkYiMUqGbNKWU6VLKVCllLynlJQCZ0JyaUF4I8aYuVi6E8BRC/E8IcV4IMfb5ryCixx09ehStW7fG3bt3ERsby7IjKoI2V1rpIoQ4B+AigH0AUqEZ+ZWIEMIawEJo9gc2AtBLCNGopMslshTx8fFwbdsW2bDGkk074eTkpHQkIqOmzUErXwJwBnBWSlkbQAcA+3Ww7hYAzkspL0gpswBsBOCrg+USmbSUS3ewMPY8Ui7dKfQ5v/zyC1QqD2SXfgXlu32NCbFpRT6fiLQrvGwpZRoAKyGElZQyFoAu/pR8A8CVx76++ug+IouVcukOglbsx+zI/yFoxf4CS2zNmjXo2rUr/lO7Hv4TNANWFSojO0eN/RfSFEhMZDq0Kbx/hBDlAcQDWC+EmIdH19U0BCHE+0KIw0KIwzdv3jTUaokUsf9CGrJy1FBLFFhic+bMQf/+/dG+fXus27IDL1V4BdYCsLWxgnMdO4VSE5kGbSaA9QXwAMAnAIIAvAxgqg7WfQ1Ajce+rv7ovidIKZcBWAYATk5OUgfrJTJaznXsUMrGCtk56idKTEqJCRMmYPr06ejevTvWrVuH0qVLY/3g8th/IQ3OdezQvGYlhdMTGTehuS60Fk8UoiIeK8iSnpoghLABcBaafYLXABwC0FtK+Xthr3FycpKHDx8uyWqJjF7KpTtPlFhubi6GDRuG5cuXY+jQoVi4cCGsra2VjklkNIQQKVLK5+5q0+bSYkMBfAHNKE8NQEBzekKdkgSUUuYIIT4EEAHAGsAPRZUdkaVoXrNS/mjt4cOHCAoKwpYtWzBx4kRMnTqV18UkekHabNIcA8BeSnlL1yuXUu4CsEvXyyXSm+RkIC4OcHMDXFz0uqqMjAz4+/sjJiYGc+fOxciRI/W6PiJzp03h/QHgvr6DEBm95GSgQwcgKwsoVQqIidFb6d26dQteXl44evQo1q5di759++plPUSWRJvCGwfgVyHEAQAP8+6UUo7QWyoiYxQXpym73FzNbV
ycXgrv8uXLUKlUuHTpErZt24bOnTvrfB1ElkibwlsKYC+AE9DswyOyTG5umpFd3gjPzU3nqzh9+jRUKhUyMjIQGRkJV1dXna+DyFJpU3i2UspRek9CZOxcXDSbMfW0D+/QoUPw8vKCjY0N9u3bh6ZNm+p0+USWTpvC2y2EeB9AOJ7cpMkZE8jyuLgUXHQlPJglOjoafn5+qFKlCqKiolC3bt0SRyWiJ2lTeL0e3Y577L4Sn5ZAZDZKeDBLWFgYgoKCUL9+fURERKBq1ap6DEtkuZ57aTEpZe0CPlh2RHkKOphFS8uXL0dgYCCcnJywb98+lh2RHhU143l7KeVeIUTXgh6XUm7VXywiE/ICB7NIKTFjxgyMGzcOXl5eCAsLQ9myZfUelciSFbVJsy00R2d2KeAxCYCFRwQU+2AWtVqNkJAQzJkzB0FBQVi1ahVsbW0NEpXIkhVaeFLKyY8+nSqlvPj4Y0KI2npNRWRqCjuY5Sk5OTkYPHgw1qxZg48++gjfffcdrKy0mbSEiEpKm/9pWwq4L0zXQYjMXWZmJgICArBmzRpMnToV8+bNY9kRGVBR+/AaAGgM4OWn9uNVBFBG38GIzEl6ejp8fHyQkJCARYsWYdiwYUpHIrI4Re3Dqw+gM4BX8OR+vAwAQ/QZisic/PXXX/D09MTvv/+O0NBQ9OjRQ+lIRBapqH142wFsF0K4SCmTDZiJyGxcvHgRKpUK169fR3h4ODw8PJSORGSxtNmB4C+EqCiEsBVCxAghbgoh+ug9GZGJO3nyJFq1aoW0tDTExMSw7IgUpk3hqaSU/0KzeTMVwH8BhOgzFJGp+/XXX+Hq6gohBBISEuDs7Kx0JCKLp03h5Z0g5A1gs5QyXY95iEze7t270bFjR1SuXBlJSUlo3Lix0pGICNoVXrgQ4gyA5gBihBCVATzQbywi0xQaGgofHx80aNAAiYmJqFWrltKRiOgRba6lORZASwBOUspsaGY/99V3MCJTs2DBAgQFBaFVq1aIi4tDlSpVlI5ERI8ptPCEEJ8+9mUHKWUuAEgp7wHgbOdEj0gpMWXKFHz00Ufw8fHBnj17ULFiRaVjEdFTihrh9Xzs83FPPeaphyxEGsnJwPTpmlsjp1arMWLECHzxxRfo378/wsLCUKYMr8tAZIyKOvFcFPJ5QV8T6UYJ55YzpKysLPTv3x+hoaEYM2YMZs6cCSH4X4PIWBU1wpOFfF7Q10S6UYK55Qzp3r178PX1RWhoKGbMmIFZs2ax7IiMXFEjvKZCiH+hGc299OhzPPqa22xIP15gbjlDu337Njp37owDBw5g+fLlGDx4sNKRiEgLRV1azNqQQYgAFHtuOUO7fv06PDw8cPbsWWzevBlduxY4P7LhJCcb7XtFZGyKGuERKUPLueUM7fz583B3d8etW7ewe/dutG/fXtlAJrS/k8gYcDIuIi0cPXoUrVq1wt27dxEbG6t82QEms7+TyFiw8IieIz4+Hm5ubihdujQSExPh5OSkdCSNvP2d1tZGu7+TyJhwkyZREX755Rf06NEDtWvXRkREBGrUqKH/lWq7X87I93cSGRsWHlEh1qxZg0GDBqF58+bYtWsX7Ozs9L/S4u6XM9L9nUTGiJs0iQowZ84c9O/fH+3bt0dMTIxhyg7gfjkiPWLhET1GSonx48dj9OjR6N69O8LDw1G+fHnDBeB+OSK94SZNokdyc3MxbNgwLF++HEOHDsXChQthbW3g01G5X45Ib1h4RAAePnyIoKAgbNmyBRMmTMC0adOUu1QY98sR6QULjyxeRkYG/P39ERMTgzlz5uCTTz5ROhIR6QELjyzarVu30KlTJxw5cgRr1qzBe++9V+TzUy7dwf4LaXCuY4fmNSsZKCUR6QILj0yPjq4feeXKFahUKqSmpuLnn39Gly5dinx+yqU7CFqxH1k5apSyscL6wc4sPSITwsIj06Kj60eeOXMGKpUK6enpiIyMhKur63Nfs/9CGrJy1FBLID
tHjf0X0vILjyM/IuPHwiPTUtB5asUsvEOHDsHLyws2NjbYt28fHB0dtXqdcx07lLKxQnaOGrY2VnCuozk3jyM/ItPAwiPTUsL58mJiYuDn54fKlSsjKioKdevW1fq1zWtWwvrBzs+M5Ioa+RGR8WDhkWkpwXlqW7ZsQe/evVG/fn1ERESgatWqxV5985qVnimzwkZ+RGRchJRS6Qxac3JykocPH1Y6Bpmg5cuXIzg4GM7OztixYwcqVdLtCIz78IiUI4RIkVI+dxoTjvDIrEkpMWPGDIwbNw5eXl4ICwtD2bJldb6egkZ+RGRceC1NMltqtRpjxozBuHHjEBQUhO3bt+ul7IjINHCER2YpJycHgwcPxpo1a/DRRx/hu+++g5UV/74jsmT8DUBmJzMzEwEBAVizZg2mTp2KefPmseyISJnCE0J0F0L8LoRQCyGeu6ORSFvp6enw9PREeHg4Fi5ciM8//1y5i0ATkVFRapPmSQBdASxVaP1khv766y94enri5MmT2LBhA3r27Kl0JCIyIooUnpTyNAD+5U06k5qaCnd3d1y/fh3h4eHw9PRUOhIRGRmjP2hFCPE+gPcB4M0331Q4DRmjkydPwsPDA5mZmYiOjoYL55IjogLobR+eECJaCHGygA/f4ixHSrlMSukkpXSqXLmyvuKSiUpOTkabNm0AAAkJCSw7IiqU3kZ4UsqO+lo2EQDs2bMHAQEBqFatGqKiolCrVi2lIxGREeOx2mSSQkND0aVLF9SvXx+JiYksOyJ6LqVOS/AXQlwF4AJgpxAiQokcZISSk4Hp0zW3hVi4cCGCgoLQqlUrxMbG4vXXXzdgQCIyVUodpfkzgJ+VWDcZsUeTu8qHWcixtcUfG7ahQVeP/IellJg6dSqmTJkCX19fbNy4EWXKlFEwMBGZEm7SJOMRFwf5MAtCnQuRlYVdC0KRcukOAM11MUeMGIEpU6agf//+CAsLY9kRUbGw8Mh4uLkhx9YWOcIK2dY2+LW6vWZy1aws9OnTBwsWLMCYMWPwww8/wMbG6M+oISIjw98aZDxcXPDHhm3YtSAUv1a3x8majTHyP2Xg6+uLPXv2YMaMGfj000+VTklEJoqFR0alQVcP3GveAqUvpOGDV63w6aBAHDhwAMuXL8fgwYOVjkdEJoyFR0anec1KqGqbCQ8PD5w9exabN29G165dlY5FRCaOhUdG5/z583B3d8etW7ewe/dutG/fXulIRGQGWHhkVI4dOwYPDw+o1WrExsbCyYmzRxGRbvAoTTIa8fHxaNu2LUqXLo2EhASWHRHpFAuPjEJ4eDg8PDxQrVo1JCUloUGDBkpHIiIzw8Ijxa1duxb+/v5wcHBAQkICatSooXQkIjJDLDxS1Ny5c9GvXz+0a9cOMTExeO2115SORERmioVHipBSYsKECRg1ahS6deuGHTt2oEKFCkrHIiIzxqM0yeByc3PxwQcfYNmyZRg6dCgWLlwIa2trpWMRkZnjCI8M6uHDh+jZsyeWLVuGCRMmYPHixSw7IjIIjvDIYDIyMuDv74+YmBjMmTMHn3zyidKRiMiCsPDIIG7duoVOnTrhyJEjWLNmDd577z2lI+lWcjIQFwe4uQEuLkqnIaICsPBI765cuQKVSoXU1FT8/PPP6NKli9KRdOvRxLXIygJKlQJiYlh6REaI+/BIr86cOYNWrVrh+vXriIiI0EnZpVy6g4Wx5/MnhzWo5GRg+nTNbZ64OE3Z5eZqbuPiDJ+LiJ6LIzzSm0OHDsHLywvW1tbYt28fHB0dS7zMlEt3ELRiP7Jy1ChlY4X1g53RvGYlHaTVQmEjOTc3zdd597u5GSYPERULR3ikFzExMWjfvj0qVKiApKQknZQdAM0M6DlqqCWQnaPG/gtpOlmuVgobybm4aMpv2jRuziQyYhzhkc5t3boVvXr1Qr169RAREYFq1arpbNnOdexQysYK2Tlq2NpYwbmOnc6W/VxFjeRcXFh0REaOhUc6tWLFCg
wdOhTOzs7YsWMHKlXS7ebG5jUrYf1gZ+y/kAbnOnaG25wJ/P9IjkdjEpkkIaVUOoPWnJyc5OHDh5WOQQWQUmLmzJkYO3YsvLy8sHnzZpQrV07pWERkAYQQKVLK584nxn14VGJSSoSEhGDs2LHo3bs3tm/fzrIjIqPDTZpUIjk5ORgyZAhWr16NDz/8EPPmzYOVFf+OIiLjw99M9MIyMzMREBCA1atX44svvsD333/PsiMio8URHr2Q9PR0+Pr6Ij4+HgsXLsQHH3ygdCQioiKx8KjY/vrrL3h6euLkyZPYsGEDevbsqXQkIqLnYuFRsaSmpsLd3R3Xr19HeHg4PD09lY5ERKQVFh5p7eTJk/Dw8EBmZiaio6PhwvPQiMiE8AgD0kpycjLatGkDAEhISGDZEZHJYeHRc+3ZswcdO3aEnZ0dkpKS0LhxY6UjEREVGwuPihQaGoouXbqgXr16SExMRK1atZSORET0Qlh4VKhFixYhKCgILVu2RFxcHF5//XWlIxERvTAWHj1DSompU6di+PDh6NKlC/bs2YOXX37ZcAEKmmSViKiEeJQmPUGtVmPkyJGYP38++vXrhxUrVsDGxnD/TM5sjUDd3n6wyc6GKF2K88sRkc5whEf5srOz0bdvX8yfPx+jR4/GDz/8YNCyS7l0B7sWhEJkZUGocyEfn2SViKiEWHgEALh//z58fX2xYcMGfPPNN5g1a5bBr4u5/0IakqrbI9vaBjnCCjk2tk9OskpEVALcpEm4c+cOOnfujP3792PZsmUYMmSIIjmc69hhfs3G6NvrK7S8ehKdPuyFBtycSUQ6wsKzcNevX4eHhwfOnj2Ln376CQEBAYpl+f/ZzN+Cc51BaGDI2cyJyOyx8CzY+fPnoVKpcPPmTezevRvt27dXOhKa16yE5iw6ItIDFp6FOnbsGDw9PZGbm4vY2Fg4OTkpHYmISK940IoFio+PR9u2bVGqVCkkJCSw7IjIIrDwLEx4eDg8PDxQrVo1JCUloUGDBkpHIiIyCBaeBVm7di38/f3h4OCAhIQE1KhRQ+lIREQGo0jhCSFmCSHOCCF+E0L8LIR4RYkclmTu3Lno168f3NzcEBMTg9dee03pSEREBqXUCC8KgL2UsgmAswDGKZTD7EkpMWHCBIwaNQoBAQHYuXMnKlSooHQsIiKDU6TwpJSRUsqcR1/uB1BdiRzmLjc3F8HBwfj666/x/vvvY9OmTShdurTSsYiIFGEM+/AGAtitdAhz8/DhQ/Ts2RPLli3D+PHjsWTJElhbWysdi4hIMXo7D08IEQ3gPwU8NEFKuf3RcyYAyAGwvojlvA/gfQB488039ZDU/Ny9exf+/v6Ijo7G7NmzMWrUKKUjEREpTm+FJ6XsWNTjQoj+ADoD6CCllEUsZxmAZQDg5ORU6PNI49atW/D29kZKSgrWrFmD9957T+lIRERGQZErrQghPAF8CqCtlPK+EhnM0ZUrV6BSqZCamoqff/4ZXbp0UToSEZHRUOrSYgsAlAYQJYQAgP1SymCFspiFM2fOQKVSIT09HREREWjTpo3SkYiIjIoihSel/K8S6zVXhw8fhpeXF6ysrLBv3z44OjoqHUm/kpM1E8O6uXE2dCLSGi8ebeL27t0LX19fvPbaa4iKisJ//2vmf0skJwMdOgBZWUCpUkBMDEuPiLRiDKcl0AvaunUrvLy8UKtWLSQlJZl/2QGakV1WFpCbq7mNi1M6ERGZCBaeiVqxYgW6d+8OJycnxMfHo1q1akpHMgw3N83Iztpac+vmpnQiIjIR3KRpYqSUmDlzJsaOHQsvLy9s3rwZ5cqVUzqW4bi4aDZjch8eERUTC8+ESCkREhKC2bNno3fv3li9ejVsbW2VjmV4Li4sOiIqNhaeicjJycGQIUOwevVqfPjhh5g3bx6srLhFmohIW/yNaQIePHiAbt26YfXq1ZgyZQq+//57lh0RUTFxhGfk0tPT4evri/j4eCxYsADDhw9XOhIRkU
li4Rmxv//+G56enjhx4gTWr1+PXr16KR2JiMhksfCMVGpqKlQqFa5evYrw8HB4enoqHYmIyKSx8IzQ77//DpVKhczMTMTExMCFRyQSEZUYj3wwMsnJyXB1dYWUEvHx8Sw7IiIdYeEZkYiICHTs2BF2dnZISkqCvb290pGIiMwGC89IbNy4EV26dEG9evWQmJiI2rVrKx2JiMissPCMwKJFi9C7d2+4uLggLi4Or7/+utKRiIjMDgtPQVJKTJ06FcOHD0eXLl2wZ88evPzyy0rHIiIySzxKUyFqtRojR47E/Pnz0a9fP6xYsQI2NvxxEBHpC0d4CsjOzkbfvn0xf/58jB49Gj/88APLjohIz/hb1sDu37+Pbt26Yffu3Zg+fTo+++wzCCGUjkVEZPZYeAZ0584ddO7cGfv378eyZcswZMgQpSMZn+RkznVHRHrBwjOQP//8EyqVCmfPnsVPP/2EgIAApSMZn+RkoEMHICtLM5t5TAxLj4h0hvvwDOCPP/5Aq1atkJqail27drHsChMXpym73FzNbVyc0omIyIxwhKdnx48fh4eHB3JycrB371688847SkcyXm5umpFd3gjPzU3pRERkRlh4epSQkIAuXbqgQoUKiIuLQ4MGDZSOZNxcXDSbMbkPj4j0gIWnJzt27ED37t1Rq1YtREZGokaNGiqrNDUAAAmKSURBVEpHMg0uLiw6ItIL7sPTg7Vr18LPzw8ODg5ISEhg2RERGQEWno5999136NevH9zc3BATE4PXXntN6UhERAQWns5IKTFx4kR88sknCAgIwM6dO1GhQgWlYxER0SPch6cDubm5GD58OJYuXYr3338fixYtgrW1tdKxiIjoMRzhldDDhw/Rs2dPLF26FOPHj8eSJUtYdkRERogjvBK4e/cu/P39ER0djdmzZ2PUqFFKRyIiokKw8F7QrVu34O3tjZSUFKxevRr9+vVTOhIRERWBhfcCrly5ApVKhYsXL2Lr1q3w8fFROhIRET0HC6+Y/ve//8Hd3R3p6emIiIhA27ZtlY5ERERaYOEVQ0pKCjw9PWFlZYW4uDi8/fbbSkciIiIt8ShNLcXGxsLNzQ3ly5dHYmIiy46IyMSw8LSwdetWeHp6olatWkhKSsJbb72ldCQiIiomFt5zrFy5Et27d0fz5s2xb98+VKtWTelIRET0Alh4RZg5cyYGDx4MlUqFqKgovPrqq0pHIiKiF8TCK4CUEiEhIfjss8/Qq1cvbN++HeXKlVM6FhERlQCP0nxKTk4O3n//faxatQoffvgh5s2bBysr/l1ARGTq+Jv8MQ8ePEC3bt2watUqTJkyBd9//z3LjojITHCE90h6ejp8fX0RHx+PBQsWYPjw4UpHIiIiHWLhAfj777/h6emJEydOYP369ejVq5fSkYiISMcsvvBSU1OhUqlw9epVhIeHw9PTU+lIRESkBxZdeL///jtUKhXu37+P6OhotGzZUulIRESkJ4ockSGEmCaE+E0IcUwIESmEMPjZ3MnJyXB1dYWUEvHx8Sw7IiIzp9QhiLOklE2klI4AdgCYZMiVR0REoGPHjnj11VeRlJQEBwcHQ66eiIgU8H/t3X+oX3Udx/Hna1PbddkMHWZu10Zd1o8RRmvIElm0bERzjSiMhkYsGRhpEGmuWlYD+0FIQZC0QcKwCbNctuEaSeVitTu5Tt1mrcZQqczm0pUg677643sufLvsbnfde75n93xeD7hwzvf7Oee8P9x7v697zvnc82kk8Gy/2LU6E3Cvjr1582aWL1/OwMAAu3btYt68eb06dERENKixe3iS1gPXA/8E3tOLYx4+fJhVq1axePFitm7dyqxZs3px2IiIOAvIrufkStJO4HUneWut7Qe62n0BmGF73Rj7uRG4EaC/v/+dR44cmVBd27dvZ8mSJfT19U1oPxERcXaQtNf2wtO2qyvwxktSP7DN9oLTtV24cKEHBwd7UFVEREwV4w28pkZpdk8otwI42EQdER
FRjqbu4d0paT4wDBwB1jRUR0REFKKRwLP94SaOGxER5cpUABERUYQEXkREFCGBFxERRUjgRUREERJ4ERFRhAReREQUIYEXERFFSOBFREQREngREVGEBF5ERBQhgRcREUVI4EVERBEanw/vTEj6O53ZFSbqYuD5SdjPVJC+tlP62l4l9Xey+nq57dmnazSlAm+ySBocz2SBbZC+tlP62l4l9bfXfc0lzYiIKEICLyIiilBq4N3ddAE9lL62U/raXiX1t6d9LfIeXkRElKfUM7yIiChMsYEn6WuS9kkakrRD0uubrqkukr4l6WDV359IurDpmuoi6SOSnpQ0LKmVI90kLZP0lKRDkm5rup66SNoo6TlJTzRdS90kzZX0sKT91c/vzU3XVBdJMyT9XtJjVV/v6NmxS72kKek1tl+slj8DvNX2mobLqoWka4Bf2j4h6RsAtm9tuKxaSHoLMAz8APic7cGGS5pUkqYDfwDeBzwD7AE+Znt/o4XVQNLVwHHgHtsLmq6nTpIuBS61/aikC4C9wIda+n0VMNP2cUnnAo8AN9veXfexiz3DGwm7ykygtclve4ftE9XqbmBOk/XUyfYB2081XUeNFgGHbP/Z9ivAj4EVDddUC9u/Bo42XUcv2P6L7Uer5ZeAA8BlzVZVD3ccr1bPrb568vlbbOABSFov6Wng48CXm66nRz4JbG+6iPi/XQY83bX+DC39YCyVpDcA7wB+12wl9ZE0XdIQ8BzwC9s96WurA0/STklPnORrBYDttbbnApuATzdb7cScrq9Vm7XACTr9nbLG09eIqUjSq4EtwC2jrkK1iu3/2L6CztWmRZJ6csn6nF4cpCm2l46z6SZgG7CuxnJqdbq+SvoE8EHgvZ7iN27P4PvaRs8Cc7vW51SvxRRX3c/aAmyyfX/T9fSC7WOSHgaWAbUPTmr1Gd6pSBroWl0BHGyqlrpJWgZ8HrjW9r+bricmZA8wIGmepPOA64CtDdcUE1QN5NgAHLD9nabrqZOk2SMjxSX10RmA1ZPP35JHaW4B5tMZ0XcEWGO7lX8pSzoEvAr4R/XS7haPSF0JfA+YDRwDhmy/v9mqJpekDwB3AdOBjbbXN1xSLSTdCyyh80T9vwHrbG9otKiaSLoK+A3wOJ3PJIDbbW9rrqp6SHo78CM6P7/TgPtsf7Unxy418CIioizFXtKMiIiyJPAiIqIICbyIiChCAi8iIoqQwIuIiCIk8CJqIOmiaiaOIUl/lfRs1/p5k3ic1ZLuOoP209o8w0LEqeTfEiJqJukrwHHb3x71uuj8Dg6fdMPx7Xs1sMD2LeNsfw7wvO3WThEVMZac4UX0kKQ3VXOebQKeBOZKOtb1/nWSflgtXyLpfkmD1fxhV46x28sl/UrSHyV9sWtfN1TbDUn6vqRpwJ3ABdVr91TtfiZpbzU32eraOh/RsFY/SzPiLPVm4Hrbg9UZ11i+C3zT9u7qCfoPAid7yO6i6vVXgD2SHqTzkPCVwOJqHsS76TyG7DZgdfXg3hE32D4q6XxgUNIW2y9MsI8RZ50EXkTv/WmcE9MuBeZ3rnwC8FpJfbZfHtXuoZGAkvRT4Co6v9vvohNgAH3877RC3T4r6dpqeQ7wRqBVE+dGQAIvogn/6loeBtS1PqNrWcCiaqLXUxl9I97Vthttf6n7jdFnlJKWAlcDV9p+WdIjo2qIaI3cw4toUDVg5QVJA9U9tpVdb+8EbhpZkXTF6O0r10i6sLokuQLYVW37UUkXV9teJKnf7sx83xV8s4CjVdi9jc5ZYUQrJfAimncr8BDwWzozmI+4CXi3pH2S9gOfGmP7PcADwGPAvbaHbD8O3AHslLQP2AFcUrXfAOyrBq38HDi/2v/XafEs2xH5t4SIiChCzvAiIqIICbyIiChCAi8iIoqQwIuIiCIk8CIioggJvIiIKEICLyIiipDAi4iIIvwXWuZY0SMDM5QAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure(figsize=(7, 7))\n", "plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", "plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n", "plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')\n", "plot_scale = 3\n", "plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", "xlabel('True beta')\n", "ylabel('Estimated beta')\n", "legend(loc='best')" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "_bXdOlvUEJl0" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "", "kind": "local" }, "name": "vmapped log-probs.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.10" } }, "nbformat": 4, "nbformat_minor": 1 }