{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "6umP1IKf4Dg6" }, "source": [ "# Autobatching log-densities example\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", "\n", "Inspired by a notebook by @davmre." ] }, { "metadata": { "colab_type": "code", "id": "PaW85yP_BrCF", "colab": {} }, "cell_type": "code", "source": [ "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.11-cp36-none-linux_x86_64.whl\n", "!pip install --upgrade -q jax" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "8RZDkfbV3zdR" }, "outputs": [], "source": [ "from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "\n", "import functools\n", "import itertools\n", "import re\n", "import sys\n", "import time\n", "\n", "from matplotlib.pyplot import *\n", "\n", "import jax\n", "\n", "from jax import lax\n", "from jax import numpy as np\n", "from jax import scipy\n", "from jax import random\n", "\n", "import numpy as onp\n", "import scipy as oscipy" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "p2VcZS1d34C6" }, "source": [ "# Generate a fake binary classification dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "pq41hMvn4c_i" }, "outputs": [], "source": [ "onp.random.seed(10009)\n", "\n", "num_features = 10\n", "num_points = 100\n", "\n", "true_beta = onp.random.randn(num_features).astype(np.float32)\n", "all_x = onp.random.randn(num_points, num_features).astype(np.float32)\n", "y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)" ] }, { "cell_type": "code", 
"execution_count": 3, "metadata": { "colab": { "height": 102 }, "colab_type": "code", "executionInfo": { "elapsed": 30, "status": "ok", "timestamp": 1549999404494, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "O0nVumAw7IlT", "outputId": "c474098f-4e81-4fc8-ad8f-3ba825409be3" }, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,\n", " 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,\n", " 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,\n", " 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,\n", " 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DZRVvhpn5aB1" }, "source": [ "# Write the log-joint function for the model\n", "\n", "We'll write a non-batched version, a manually batched version, and an autobatched version." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "C_mDXInL7nsP" }, "source": [ "## Non-batched" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "ZHyL2sJh5ajG" }, "outputs": [], "source": [ "def log_joint(beta):\n", " result = 0.\n", " # Note that no `axis` parameter is provided to `np.sum`.\n", " result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))\n", " result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n", " return result" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "height": 34 }, "colab_type": "code", "executionInfo": { "elapsed": 3383, "status": "ok", "timestamp": 1549999409301, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "e51qW0ro6J7C", "outputId": "c778d4fc-85b9-4fea-9875-0d0a3397a027" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/lib/xla_bridge.py:146: UserWarning: No GPU found, falling back to CPU.\n", " warnings.warn('No GPU found, falling back to CPU.')\n" ] }, { "data": { "text/plain": [ "array(-213.23558, dtype=float32)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_joint(onp.random.randn(num_features))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "height": 895 }, "colab_type": "code", "executionInfo": { "elapsed": 4130, "status": "error", "timestamp": 1549999413496, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "fglQXK1Y6wnm", "outputId": "cf85d9a7-b403-4e75-efb6-b9d057e66f3c" }, "outputs": [ { 
"ename": "ValueError", "evalue": "Incompatible shapes for broadcasting: ((100, 10), (1, 100))", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mbatched_test_beta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mlog_joint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mlog_joint\u001b[0;34m(beta)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Note that no `axis` parameter is provided to `np.sum`.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscipy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogpdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mscale\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbeta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc\u001b[0m in \u001b[0;36m\u001b[0;34m(x, y)\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlax_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0m_promote_args_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumpy_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 242\u001b[0;31m \u001b[0mfn\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlax_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0m_promote_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumpy_fn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 243\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_wraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumpy_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc\u001b[0m in \u001b[0;36m_promote_args\u001b[0;34m(fun_name, *args)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;34m\"\"\"Convenience function to apply Numpy argument shape and dtype promotion.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0m_check_arraylike\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_promote_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0m_promote_dtypes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc\u001b[0m in \u001b[0;36m_promote_shapes\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 137\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0mshapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0mnd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbroadcast_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mshapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)\n\u001b[1;32m 141\u001b[0m if len(shp) != nd else arg for arg, shp in zip(args, shapes)]\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/util.pyc\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpopitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlast\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mmemoized_fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/lax.pyc\u001b[0m in \u001b[0;36mbroadcast_shapes\u001b[0;34m(*shapes)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0monp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshapes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mresult_shape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mshapes\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m raise ValueError(\"Incompatible shapes for broadcasting: {}\"\n\u001b[0;32m---> 69\u001b[0;31m .format(tuple(map(tuple, shapes))))\n\u001b[0m\u001b[1;32m 70\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: Incompatible shapes for broadcasting: ((100, 10), (1, 100))" ] } ], "source": [ "# This doesn't work, because we didn't write `log_joint()` to handle batching.\n", "batch_size = 10\n", "batched_test_beta = onp.random.randn(batch_size, num_features)\n", "\n", "log_joint(onp.random.randn(batch_size, num_features))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_lQ8MnKq7sLU" }, "source": [ "## Manually batched" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "2g5-4bQE7gRA" }, "outputs": [], "source": [ "def batched_log_joint(beta):\n", " result = 0.\n", " # Here (and below) `sum` needs an `axis` parameter.
At best, forgetting to set axis\n", " # or setting it incorrectly yields an error; at worst, it silently changes the\n", " # semantics of the model.\n", " result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),\n", " axis=-1)\n", " # Note the multiple transposes. Getting this right is not rocket science,\n", " # but it's also not totally mindless. (I didn't get it right on the first\n", " # try.)\n", " result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta.T).T)),\n", " axis=-1)\n", " return result" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "height": 68 }, "colab_type": "code", "executionInfo": { "elapsed": 735, "status": "ok", "timestamp": 1549999417264, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "KdDMr-Gy85CO", "outputId": "1e90fc29-60fb-4460-f08f-2dd486cc8f5e" }, "outputs": [ { "data": { "text/plain": [ "array([-147.84033 , -207.02205 , -109.26075 , -243.8083 , -163.02911 ,\n", " -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ],\n", " dtype=float32)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 10\n", "batched_test_beta = onp.random.randn(batch_size, num_features)\n", "\n", "batched_log_joint(batched_test_beta)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-uuGlHQ_85kd" }, "source": [ "## Autobatched with vmap\n", "\n", "It just works. `jax.vmap` maps the unbatched `log_joint` over the leading axis of its input, so there are no `axis` arguments or transposes to get right, and the result below matches the manually batched output above."
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "height": 68 }, "colab_type": "code", "executionInfo": { "elapsed": 174, "status": "ok", "timestamp": 1549999417694, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "SU20bouH8-Za", "outputId": "5637b58a-0d7e-4a61-b74a-f4d2cab2105a" }, "outputs": [ { "data": { "text/plain": [ "array([-147.84033 , -207.02205 , -109.26075 , -243.8083 , -163.02911 ,\n", " -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ],\n", " dtype=float32)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vmap_batched_log_joint = jax.vmap(log_joint)\n", "vmap_batched_log_joint(batched_test_beta)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "L1KNBo9y_yZJ" }, "source": [ "# Self-contained variational inference example\n", "\n", "Some of the code above is repeated here so that this section can run on its own. Note that the redefined `log_joint` uses a wider prior on `beta` (`scale=10.` rather than `scale=1.`), and that it and `batched_log_joint` are now compiled with `jax.jit`."
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "lQTPaaQMJh8Y" }, "source": [ "## Set up the (batched) log-joint function" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "AITXbaofA3Pm" }, "outputs": [], "source": [ "@jax.jit\n", "def log_joint(beta):\n", " result = 0.\n", " # Note that no `axis` parameter is provided to `np.sum`.\n", " result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))\n", " result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))\n", " return result\n", "\n", "batched_log_joint = jax.jit(jax.vmap(log_joint))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UmmFMQ8LJk6a" }, "source": [ "## Define the ELBO and its gradient" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": {}, "colab_type": "code", "id": "MJtnskL6BKwV" }, "outputs": [], "source": [ "def elbo(beta_loc, beta_log_scale, epsilon):\n", " beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon\n", " return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))\n", "elbo = jax.jit(elbo, static_argnums=(2, 3))\n", "elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oQC7xKYnJrp5" }, "source": [ "## Optimize the ELBO using SGD" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "height": 1717 }, "colab_type": "code", "executionInfo": { "elapsed": 2986, "status": "ok", "timestamp": 1549999510348, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "9JrD5nNgH715", "outputId": "1b7949cc-1296-46bb-9d88-412475834944" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"0\t-180.853881836\n", "10\t-113.060455322\n", "20\t-102.737258911\n", "30\t-99.7873535156\n", "40\t-98.9089889526\n", "50\t-98.297454834\n", "60\t-98.1863174438\n", "70\t-97.5797195435\n", "80\t-97.2860031128\n", "90\t-97.4699630737\n", "100\t-97.4771728516\n", "110\t-97.5806732178\n", "120\t-97.494354248\n", "130\t-97.5027313232\n", "140\t-96.8639526367\n", "150\t-97.4419784546\n", "160\t-97.0694046021\n", "170\t-96.8402862549\n", "180\t-97.2133789062\n", "190\t-97.5650253296\n", "200\t-97.2639770508\n", "210\t-97.1197967529\n", "220\t-97.395942688\n", "230\t-97.1683197021\n", "240\t-97.1184082031\n", "250\t-97.2434539795\n", "260\t-97.2978668213\n", "270\t-96.692855835\n", "280\t-96.9643859863\n", "290\t-97.3005523682\n", "300\t-96.6359176636\n", "310\t-97.0351867676\n", "320\t-97.529083252\n", "330\t-97.2881164551\n", "340\t-97.0732192993\n", "350\t-97.1561889648\n", "360\t-97.2588195801\n", "370\t-97.1951446533\n", "380\t-97.1309204102\n", "390\t-97.1172637939\n", "400\t-96.9387359619\n", "410\t-97.2667694092\n", "420\t-97.353225708\n", "430\t-97.2100753784\n", "440\t-97.2843475342\n", "450\t-97.1630859375\n", "460\t-97.2612457275\n", "470\t-97.2134399414\n", "480\t-97.2399749756\n", "490\t-97.1491317749\n", "500\t-97.2352828979\n", "510\t-96.9342041016\n", "520\t-97.212097168\n", "530\t-96.8257751465\n", "540\t-97.0128479004\n", "550\t-96.9417648315\n", "560\t-97.1652069092\n", "570\t-97.2916564941\n", "580\t-97.429397583\n", "590\t-97.2437133789\n", "600\t-97.1521911621\n", "610\t-97.4984436035\n", "620\t-96.9906997681\n", "630\t-96.8895645142\n", "640\t-96.8996887207\n", "650\t-97.1379394531\n", "660\t-97.4370574951\n", "670\t-96.9923629761\n", "680\t-97.1562423706\n", "690\t-97.1869049072\n", "700\t-97.1116027832\n", "710\t-97.7810516357\n", "720\t-97.2322616577\n", "730\t-97.1620635986\n", "740\t-96.9958190918\n", "750\t-96.6672210693\n", "760\t-97.1679534912\n", "770\t-97.5143508911\n", "780\t-97.2890090942\n", "790\t-96.9122619629\n", 
"800\t-97.1709976196\n", "810\t-97.290473938\n", "820\t-97.1624298096\n", "830\t-97.1910629272\n", "840\t-97.5638198853\n", "850\t-97.0019378662\n", "860\t-96.8655548096\n", "870\t-96.7633743286\n", "880\t-96.8366088867\n", "890\t-97.1217956543\n", "900\t-97.0955505371\n", "910\t-97.0682373047\n", "920\t-97.1194763184\n", "930\t-96.8792953491\n", "940\t-97.4562530518\n", "950\t-96.6928024292\n", "960\t-97.293762207\n", "970\t-97.3353042603\n", "980\t-97.349609375\n", "990\t-97.0967559814\n" ] } ], "source": [ "def normal_sample(key, shape):\n", " \"\"\"Convenience function for quasi-stateful RNG.\"\"\"\n", " new_key, sub_key = random.split(key)\n", " return new_key, random.normal(sub_key, shape)\n", "normal_sample = jax.jit(normal_sample, static_argnums=(1,))\n", "\n", "key = random.PRNGKey(10003)\n", "\n", "beta_loc = np.zeros(num_features, np.float32)\n", "beta_log_scale = np.zeros(num_features, np.float32)\n", "\n", "step_size = 0.01\n", "batch_size = 128\n", "epsilon_shape = (batch_size, num_features)\n", "for i in range(1000):\n", " key, epsilon = normal_sample(key, epsilon_shape)\n", " elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(\n", " beta_loc, beta_log_scale, epsilon)\n", " beta_loc += step_size * beta_loc_grad\n", " beta_log_scale += step_size * beta_log_scale_grad\n", " if i % 10 == 0:\n", " print('{}\\t{}'.format(i, elbo_val))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "b3ZAe5fJJ2KM" }, "source": [ "## Display the results\n", "\n", "Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "height": 481 }, "colab_type": "code", "executionInfo": { "elapsed": 263, "status": "ok", "timestamp": 1549999510632, "user": { "displayName": "Matt Hoffman", "photoUrl": "https://lh3.googleusercontent.com/-r5gqCRwU9kk/AAAAAAAAAAI/AAAAAAAAALw/T9KGDIrA_iA/s64/photo.jpg", "userId": "11857134876214181812" }, "user_tz": 480 }, "id": "zt1NBLoVHtOG", "outputId": "2f0081cf-bbfe-426c-bc5e-a1c09468234a" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "figure(figsize=(7, 7))\n", "plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", "plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n", "plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')\n", "plot_scale = 3\n", "plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", "xlabel('True beta')\n", "ylabel('Estimated beta')\n", "legend(loc='best')" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "_bXdOlvUEJl0" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "", "kind": "local" }, "name": "vmapped log-probs.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.10" } }, "nbformat": 4, "nbformat_minor": 1 }