{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Benchmark NumPyro on a large dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook uses `numpyro` to replicate the experiments in reference [1], which evaluates the performance of NUTS on various frameworks. The benchmark is run with CUDA 10.1 on an NVIDIA RTX 2070." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "import numpy as np\n", "\n", "import jax.numpy as jnp\n", "from jax import random\n", "\n", "import numpyro\n", "import numpyro.distributions as dist\n", "from numpyro.examples.datasets import COVTYPE, load_dataset\n", "from numpyro.infer import HMC, MCMC, NUTS\n", "\n", "assert numpyro.__version__.startswith(\"0.8.0\")\n", "\n", "# NB: replace \"gpu\" with \"cpu\" to run this notebook on CPU\n", "numpyro.set_platform(\"gpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We follow the preprocessing steps in the [source code](https://github.com/google-research/google-research/blob/master/simple_probabilistic_programming/no_u_turn_sampler/logistic_regression.py) of reference [1]:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip.\n", "Download complete.\n", "Data shape: (581012, 55)\n", "Label distribution: 211840 has label 1, 369172 has label 0\n" ] } ], "source": [ "_, fetch = load_dataset(COVTYPE, shuffle=False)\n", "features, labels = fetch()\n", "\n", "# normalize features and add intercept\n", "features = (features - features.mean(0)) / features.std(0)\n", "features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])\n", "\n", "# make binary labels\n", "_, counts = 
np.unique(labels, return_counts=True)\n", "specific_category = jnp.argmax(counts)\n", "labels = labels == specific_category\n", "\n", "N, dim = features.shape\n", "print(\"Data shape:\", features.shape)\n", "print(\n", " \"Label distribution: {} has label 1, {} has label 0\".format(\n", " labels.sum(), N - labels.sum()\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we construct the model:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def model(data, labels):\n", " coefs = numpyro.sample(\"coefs\", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))\n", " logits = jnp.dot(data, coefs)\n", " return numpyro.sample(\"obs\", dist.Bernoulli(logits=logits), obs=labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark HMC" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of leapfrog steps: 5000\n", "avg. 
time for each step : 0.0015881952285766601\n", "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " coefs[0] 1.99 0.00 1.99 1.98 1.99 4.53 1.49\n", " coefs[1] -0.03 0.00 -0.03 -0.03 -0.03 4.26 1.49\n", " coefs[2] -0.12 0.00 -0.12 -0.12 -0.12 5.57 1.10\n", " coefs[3] -0.29 0.00 -0.29 -0.29 -0.29 4.77 1.40\n", " coefs[4] -0.09 0.00 -0.09 -0.10 -0.09 5.13 1.04\n", " coefs[5] -0.15 0.00 -0.15 -0.15 -0.15 2.61 3.11\n", " coefs[6] -0.02 0.00 -0.02 -0.02 -0.02 2.68 2.54\n", " coefs[7] -0.50 0.00 -0.50 -0.50 -0.50 11.32 1.00\n", " coefs[8] 0.27 0.00 0.27 0.27 0.27 3.25 2.03\n", " coefs[9] -0.02 0.00 -0.02 -0.02 -0.02 6.34 1.42\n", " coefs[10] -0.23 0.00 -0.23 -0.23 -0.22 3.76 1.50\n", " coefs[11] -0.31 0.00 -0.31 -0.31 -0.31 3.51 1.40\n", " coefs[12] -0.54 0.00 -0.54 -0.54 -0.54 2.64 2.52\n", " coefs[13] -1.94 0.00 -1.94 -1.94 -1.93 2.54 2.75\n", " coefs[14] 0.24 0.00 0.24 0.24 0.24 9.69 1.08\n", " coefs[15] -1.07 0.00 -1.07 -1.07 -1.07 3.85 1.85\n", " coefs[16] -1.26 0.00 -1.26 -1.26 -1.26 5.80 1.07\n", " coefs[17] -0.22 0.00 -0.22 -0.22 -0.22 4.45 1.33\n", " coefs[18] -0.08 0.00 -0.08 -0.08 -0.08 2.45 2.88\n", " coefs[19] -0.68 0.00 -0.68 -0.69 -0.68 2.72 2.12\n", " coefs[20] -0.13 0.00 -0.13 -0.13 -0.13 2.79 2.30\n", " coefs[21] -0.02 0.00 -0.02 -0.02 -0.02 8.65 1.15\n", " coefs[22] 0.02 0.00 0.02 0.02 0.02 2.73 2.32\n", " coefs[23] -0.15 0.00 -0.15 -0.15 -0.15 2.75 2.56\n", " coefs[24] -0.12 0.00 -0.12 -0.12 -0.12 3.92 1.31\n", " coefs[25] -0.32 0.00 -0.32 -0.32 -0.32 5.25 1.31\n", " coefs[26] -0.17 0.00 -0.17 -0.17 -0.17 4.08 1.13\n", " coefs[27] -1.19 0.00 -1.19 -1.19 -1.19 3.22 1.85\n", " coefs[28] -0.05 0.00 -0.05 -0.05 -0.05 7.87 1.01\n", " coefs[29] -0.03 0.00 -0.03 -0.03 -0.03 7.36 1.17\n", " coefs[30] -0.04 0.00 -0.04 -0.04 -0.04 2.88 2.06\n", " coefs[31] -0.06 0.00 -0.06 -0.06 -0.06 6.43 1.23\n", " coefs[32] -0.02 0.00 -0.02 -0.02 -0.02 6.80 1.03\n", " coefs[33] -0.03 0.00 -0.03 -0.03 -0.03 6.47 1.26\n", " coefs[34] 0.11 0.00 0.11 0.10 0.11 6.67 1.22\n", " 
coefs[35] 0.08 0.00 0.08 0.08 0.08 2.49 2.80\n", " coefs[36] -0.00 0.00 -0.00 -0.00 -0.00 6.23 1.31\n", " coefs[37] -0.07 0.00 -0.07 -0.07 -0.07 2.72 2.36\n", " coefs[38] -0.03 0.00 -0.03 -0.03 -0.03 3.97 1.52\n", " coefs[39] -0.06 0.00 -0.06 -0.06 -0.06 6.16 1.26\n", " coefs[40] -0.01 0.00 -0.01 -0.01 -0.01 2.86 2.07\n", " coefs[41] -0.06 0.00 -0.06 -0.06 -0.06 3.02 1.98\n", " coefs[42] -0.39 0.00 -0.39 -0.40 -0.39 2.67 2.45\n", " coefs[43] -0.27 0.00 -0.27 -0.27 -0.27 5.15 1.33\n", " coefs[44] -0.07 0.00 -0.07 -0.07 -0.07 5.75 1.30\n", " coefs[45] -0.25 0.00 -0.25 -0.26 -0.25 2.57 2.50\n", " coefs[46] -0.09 0.00 -0.09 -0.09 -0.09 8.72 1.00\n", " coefs[47] -0.12 0.00 -0.12 -0.12 -0.12 3.10 1.73\n", " coefs[48] -0.15 0.00 -0.15 -0.15 -0.15 4.95 1.33\n", " coefs[49] -0.05 0.00 -0.05 -0.05 -0.05 2.99 2.32\n", " coefs[50] -0.94 0.00 -0.94 -0.94 -0.94 10.08 1.00\n", " coefs[51] -0.32 0.00 -0.32 -0.32 -0.32 3.90 1.75\n", " coefs[52] -0.29 0.00 -0.29 -0.30 -0.29 13.85 1.05\n", " coefs[53] -0.31 0.00 -0.31 -0.31 -0.31 8.21 1.01\n", " coefs[54] -1.76 0.00 -1.76 -1.76 -1.76 3.24 1.54\n", "\n", "Number of divergences: 0\n" ] } ], "source": [ "# fixed step size; trajectory_length = 10 * step_size gives 10 leapfrog steps per sample\n", "step_size = jnp.sqrt(0.5 / N)\n", "kernel = HMC(\n", " model,\n", " step_size=step_size,\n", " trajectory_length=(10 * step_size),\n", " adapt_step_size=False,\n", ")\n", "mcmc = MCMC(kernel, num_warmup=500, num_samples=500, progress_bar=False)\n", "mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=(\"num_steps\",))\n", "mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n", "# time the sampling phase only (warmup ran above)\n", "tic = time.time()\n", "mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=[\"num_steps\"])\n", "num_leapfrogs = mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n", "toc = time.time()\n", "print(\"number of leapfrog steps:\", num_leapfrogs)\n", "print(\"avg. time for each step :\", (toc - tic) / num_leapfrogs)\n", "mcmc.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On CPU, we get `avg. 
time for each step : 0.02782863507270813`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmark NUTS" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of leapfrog steps: 47406\n", "avg. time for each step : 0.0022662237908313812\n", "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " coefs[0] 1.97 0.01 1.97 1.95 1.98 74.56 1.05\n", " coefs[1] -0.04 0.00 -0.04 -0.05 -0.03 59.26 0.99\n", " coefs[2] -0.07 0.01 -0.06 -0.08 -0.05 35.80 1.12\n", " coefs[3] -0.30 0.00 -0.30 -0.31 -0.29 54.31 1.00\n", " coefs[4] -0.09 0.00 -0.09 -0.10 -0.09 38.45 0.99\n", " coefs[5] -0.14 0.00 -0.14 -0.15 -0.14 26.25 1.12\n", " coefs[6] 0.23 0.04 0.24 0.19 0.30 11.98 1.18\n", " coefs[7] -0.65 0.02 -0.65 -0.69 -0.62 17.16 1.16\n", " coefs[8] 0.57 0.04 0.57 0.48 0.62 12.71 1.18\n", " coefs[9] -0.01 0.00 -0.01 -0.02 -0.01 58.92 0.99\n", " coefs[10] 0.71 0.84 0.67 -0.76 2.04 7.17 0.98\n", " coefs[11] 0.08 0.38 0.06 -0.57 0.68 7.18 0.98\n", " coefs[12] 0.39 0.84 0.35 -1.09 1.72 7.18 0.98\n", " coefs[13] -1.54 0.53 -1.56 -2.20 -0.65 10.23 0.99\n", " coefs[14] -0.48 0.52 -0.45 -1.25 0.25 16.10 0.98\n", " coefs[15] -1.83 0.31 -1.80 -2.34 -1.48 5.35 0.98\n", " coefs[16] -1.06 0.52 -0.96 -1.88 -0.19 31.52 1.00\n", " coefs[17] -0.17 0.08 -0.15 -0.30 -0.06 15.07 1.38\n", " coefs[18] -0.64 0.64 -0.59 -1.50 0.25 18.98 1.03\n", " coefs[19] -0.74 0.57 -0.71 -1.66 0.07 12.04 1.11\n", " coefs[20] -1.04 0.64 -1.14 -1.80 -0.10 16.18 1.00\n", " coefs[21] -0.01 0.01 -0.01 -0.02 0.01 12.68 1.42\n", " coefs[22] 0.03 0.02 0.04 -0.00 0.07 15.54 1.37\n", " coefs[23] -0.10 0.12 -0.07 -0.27 0.09 15.48 1.39\n", " coefs[24] -0.09 0.08 -0.07 -0.21 0.02 15.48 1.36\n", " coefs[25] -0.26 0.12 -0.24 -0.46 -0.10 15.62 1.37\n", " coefs[26] -0.12 0.09 -0.10 -0.25 0.03 15.71 1.37\n", " coefs[27] -1.11 0.47 -1.11 -1.83 -0.30 17.62 1.08\n", " coefs[28] -0.83 0.70 -0.54 -2.04 0.02 34.06 0.99\n", " coefs[29] -0.01 0.04 
0.00 -0.06 0.05 15.94 1.36\n", " coefs[30] -0.02 0.04 -0.00 -0.08 0.04 15.02 1.44\n", " coefs[31] -0.05 0.03 -0.04 -0.09 0.00 16.46 1.28\n", " coefs[32] 0.01 0.04 0.02 -0.06 0.07 15.28 1.36\n", " coefs[33] 0.04 0.07 0.05 -0.06 0.14 15.73 1.37\n", " coefs[34] 0.11 0.02 0.11 0.08 0.14 14.67 1.33\n", " coefs[35] 0.13 0.12 0.16 -0.05 0.32 15.43 1.38\n", " coefs[36] 0.07 0.16 0.11 -0.16 0.32 15.53 1.37\n", " coefs[37] 0.00 0.10 0.02 -0.16 0.14 15.53 1.38\n", " coefs[38] -0.04 0.02 -0.04 -0.06 -0.02 17.43 1.33\n", " coefs[39] -0.05 0.04 -0.04 -0.10 0.01 15.25 1.40\n", " coefs[40] 0.01 0.02 0.02 -0.02 0.05 15.66 1.35\n", " coefs[41] -0.04 0.02 -0.04 -0.08 -0.00 11.32 1.38\n", " coefs[42] -0.31 0.21 -0.26 -0.61 0.03 15.56 1.38\n", " coefs[43] -0.20 0.12 -0.18 -0.40 -0.04 15.60 1.38\n", " coefs[44] -0.01 0.11 0.02 -0.17 0.16 15.52 1.38\n", " coefs[45] -0.15 0.15 -0.11 -0.37 0.09 15.46 1.38\n", " coefs[46] -0.02 0.14 0.00 -0.23 0.20 15.83 1.37\n", " coefs[47] -0.12 0.03 -0.11 -0.16 -0.07 16.20 1.38\n", " coefs[48] -0.12 0.03 -0.12 -0.17 -0.08 16.26 1.36\n", " coefs[49] -0.04 0.01 -0.04 -0.05 -0.03 14.31 1.28\n", " coefs[50] -0.98 0.44 -0.94 -1.71 -0.33 12.09 0.98\n", " coefs[51] -0.26 0.09 -0.24 -0.40 -0.14 15.53 1.38\n", " coefs[52] -0.25 0.08 -0.23 -0.36 -0.12 15.81 1.37\n", " coefs[53] -0.26 0.06 -0.25 -0.36 -0.16 15.99 1.36\n", " coefs[54] -1.98 0.13 -1.96 -2.16 -1.81 44.87 0.98\n", "\n", "Number of divergences: 0\n" ] } ], "source": [ "mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50, progress_bar=False)\n", "mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=(\"num_steps\",))\n", "mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n", "tic = time.time()\n", "mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=[\"num_steps\"])\n", "num_leapfrogs = mcmc.get_extra_fields()[\"num_steps\"].sum().copy()\n", "toc = time.time()\n", "print(\"number of leapfrog steps:\", num_leapfrogs)\n", "print(\"avg. 
time for each step :\", (toc - tic) / num_leapfrogs)\n", "mcmc.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On CPU, we get `avg. time for each step : 0.028006251705287415`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare to other frameworks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Average time per leapfrog step:\n", "\n", "| Framework (device) | HMC | NUTS |\n", "| ------------- |----------:|----------:|\n", "| Edward2 (CPU) | | 56.1 ms |\n", "| Edward2 (GPU) | | 9.4 ms |\n", "| Pyro (CPU) | 35.4 ms | 35.3 ms |\n", "| Pyro (GPU) | 3.5 ms | 4.2 ms |\n", "| NumPyro (CPU) | 27.8 ms | 28.0 ms |\n", "| NumPyro (GPU) | 1.6 ms | 2.2 ms |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that in some situations, HMC is slower than NUTS. The reason is that the number of leapfrog steps in each HMC trajectory is fixed to $10$, while it is not fixed in NUTS." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Some takeaways:**\n", "+ The overhead of iterative NUTS is small, so most of the computation time is spent evaluating the potential function and its gradient.\n", "+ GPU outperforms CPU by a large margin. Because the dataset is large, evaluating the potential function on GPU is much faster than on CPU." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\n", "1. `Simple, Distributed, and Accelerated Probabilistic Programming`, [arXiv](https://arxiv.org/abs/1811.02091)
