# Benchmark NumPyro on a large dataset

This notebook uses `numpyro` to replicate the experiments in reference [1], which evaluates the performance of NUTS across various frameworks. The benchmark is run with CUDA 10.1 on an NVIDIA RTX 2070.

In [None]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

In [1]:
import time

import numpy as np

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.infer import HMC, MCMC, NUTS

assert numpyro.__version__.startswith("0.8.0")

# NB: replace "gpu" with "cpu" to run this notebook on CPU
numpyro.set_platform("gpu")

We apply the same preprocessing steps as the [source code](https://github.com/google-research/google-research/blob/master/simple_probabilistic_programming/no_u_turn_sampler/logistic_regression.py) of reference [1]:

In [2]:
_, fetch = load_dataset(COVTYPE, shuffle=False)
features, labels = fetch()

# standardize features (zero mean, unit variance) and add an intercept column
features = (features - features.mean(0)) / features.std(0)
features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

# binarize the labels, using the same rule as the source code of reference [1]
_, counts = np.unique(labels, return_counts=True)
specific_category = jnp.argmax(counts)
labels = labels == specific_category

N, dim = features.shape
print("Data shape:", features.shape)
print(
    "Label distribution: {} has label 1, {} has label 0".format(
        labels.sum(), N - labels.sum()
    )
)

Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip.
Download complete.
Data shape: (581012, 55)
Label distribution: 211840 has label 1, 369172 has label 0
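The binarization step above compares each label with `argmax(counts)`. Because `np.unique` returns the unique labels in sorted order, that argmax is a *position* in the sorted array rather than a label value; the two coincide only when the labels are `0, 1, 2, ...`. The toy data below is made up purely to illustrate this and the standardize-plus-intercept step; the notebook keeps the position-based comparison so that its preprocessing matches the code of reference [1].

```python
import numpy as np

# Made-up toy data for illustration only (not COVTYPE).
toy_features = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
toy_labels = np.array([1, 2, 2, 3])

# Standardize each column, then append a column of ones as the intercept.
standardized = (toy_features - toy_features.mean(0)) / toy_features.std(0)
design = np.hstack([standardized, np.ones((standardized.shape[0], 1))])

# np.unique returns the *sorted* unique labels and how often each occurs.
unique, counts = np.unique(toy_labels, return_counts=True)
position = np.argmax(counts)      # index of the largest count within `unique`
most_frequent = unique[position]  # the label value that occurs most often

binary_by_position = toy_labels == position    # the rule used in the cell above
binary_by_value = toy_labels == most_frequent  # a majority-class indicator

print(design.round(3))
print("argmax position:", position, "| most frequent label:", most_frequent)
print(binary_by_position, binary_by_value)
```

With these toy labels the position is `1` but the most frequent label is `2`, so the two rules mark different rows as positive.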


Now we construct the model, a Bayesian logistic regression with a standard normal prior on the coefficients:

In [3]:
def model(data, labels):
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    logits = jnp.dot(data, coefs)
    return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
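Every leapfrog step of HMC and NUTS evaluates the gradient of the model's potential energy, i.e. the negative log joint density. For this model both have a closed form up to an additive constant. The NumPy sketch below writes them out by hand on small synthetic data and checks the gradient with finite differences; the function names and data are made up for illustration, and NumPyro itself obtains the gradient through JAX autodiff rather than a hand-written formula.

```python
import numpy as np


def potential_energy(coefs, X, y):
    """Negative log joint of the logistic regression model, up to an additive constant."""
    logits = X @ coefs
    prior = 0.5 * np.sum(coefs**2)  # -log Normal(coefs | 0, I) without its constant
    # -log Bernoulli(y | logits) = softplus(logits) - y * logits
    likelihood = np.sum(np.logaddexp(0.0, logits) - y * logits)
    return prior + likelihood


def grad_potential_energy(coefs, X, y):
    """Analytic gradient of potential_energy with respect to coefs."""
    probs = 1.0 / (1.0 + np.exp(-(X @ coefs)))  # sigmoid(logits)
    return coefs + X.T @ (probs - y)


rng = np.random.default_rng(0)
X = np.hstack([rng.normal(size=(200, 3)), np.ones((200, 1))])  # synthetic features + intercept
y = (rng.uniform(size=200) < 0.4).astype(float)                # synthetic binary labels
coefs = rng.normal(size=4)

# Central finite differences as an independent check of the analytic gradient.
eps = 1e-6
numeric = np.array([
    (potential_energy(coefs + eps * e, X, y) - potential_energy(coefs - eps * e, X, y)) / (2 * eps)
    for e in np.eye(4)
])
print("max abs gradient error:", np.max(np.abs(numeric - grad_potential_energy(coefs, X, y))))
```

Both functions are dominated by the products `X @ coefs` and `X.T @ (...)`, whose cost grows with the number of rows of `X` (581012 for COVTYPE).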

## Benchmark HMC

In [4]:
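# fixed step size (no adaptation); a trajectory length of 10 step sizes
# makes every HMC trajectory take 10 leapfrog steps
step_size = jnp.sqrt(0.5 / N)
kernel = HMC(
    model,
    step_size=step_size,
    trajectory_length=(10 * step_size),
    adapt_step_size=False,
)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, progress_bar=False)
# run warmup first and copy its output to host memory, which waits for warmup
# to finish, so that warmup is excluded from the timing below
mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=("num_steps",))
np.asarray(mcmc.get_extra_fields()["num_steps"])
tic = time.time()
mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=("num_steps",))
# int(...) waits for sampling to finish, since JAX dispatches work asynchronously
num_leapfrogs = int(mcmc.get_extra_fields()["num_steps"].sum())
toc = time.time()
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (toc - tic) / num_leapfrogs)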
mcmc.print_summary()

number of leapfrog steps: 5000
avg. time for each step : 0.0015881952285766601

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  coefs[0]      1.99      0.00      1.99      1.98      1.99      4.53      1.49
  coefs[1]     -0.03      0.00     -0.03     -0.03     -0.03      4.26      1.49
  coefs[2]     -0.12      0.00     -0.12     -0.12     -0.12      5.57      1.10
  coefs[3]     -0.29      0.00     -0.29     -0.29     -0.29      4.77      1.40
  coefs[4]     -0.09      0.00     -0.09     -0.10     -0.09      5.13      1.04
  coefs[5]     -0.15      0.00     -0.15     -0.15     -0.15      2.61      3.11
  coefs[6]     -0.02      0.00     -0.02     -0.02     -0.02      2.68      2.54
  coefs[7]     -0.50      0.00     -0.50     -0.50     -0.50     11.32      1.00
  coefs[8]      0.27      0.00      0.27      0.27      0.27      3.25      2.03
  coefs[9]     -0.02      0.00     -0.02     -0.02     -0.02      6.34      1.42
 coefs[10]     -0.23      0.0

On CPU, we get `avg. time for each step : 0.02782863507270813`.
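The HMC kernel above fixes both the step size and the number of leapfrog steps per trajectory (10), which is why 500 samples cost exactly 5000 leapfrog steps. The following plain-NumPy sketch shows what one such transition does: resample the momentum, run a fixed number of leapfrog steps, then apply a Metropolis accept/reject correction. It runs on a standard normal target rather than the logistic regression model, assumes a unit mass matrix, and the helper names are hypothetical; it is not NumPyro's implementation.

```python
import numpy as np


def leapfrog(q, p, grad_u, step_size, num_steps):
    """Leapfrog integration of H(q, p) = U(q) + |p|^2 / 2 (unit mass matrix)."""
    q, p = q.copy(), p.copy()
    p -= 0.5 * step_size * grad_u(q)  # initial half step for the momentum
    for i in range(num_steps):
        q += step_size * p  # full step for the position
        if i < num_steps - 1:
            p -= step_size * grad_u(q)  # full momentum step, except after the last position step
    p -= 0.5 * step_size * grad_u(q)  # final half step for the momentum
    return q, p


def hmc_step(q, u, grad_u, step_size, num_steps, rng):
    """Hypothetical helper: one HMC transition with a fixed number of leapfrog steps."""
    p = rng.normal(size=q.shape)  # resample the momentum from N(0, I)
    q_new, p_new = leapfrog(q, p, grad_u, step_size, num_steps)
    # Metropolis correction for the integrator's discretization error
    log_accept = (u(q) + 0.5 * p @ p) - (u(q_new) + 0.5 * p_new @ p_new)
    if np.log(rng.uniform()) < log_accept:
        return q_new, num_steps
    return q, num_steps


def u(q):
    return 0.5 * q @ q  # standard normal target: U(q) = |q|^2 / 2


def grad_u(q):
    return q


rng = np.random.default_rng(0)
q = np.zeros(2)
samples, total_steps = [], 0
for _ in range(5000):
    q, n = hmc_step(q, u, grad_u, step_size=0.2, num_steps=10, rng=rng)
    samples.append(q)
    total_steps += n
samples = np.array(samples)
print(total_steps, samples.mean(0), samples.var(0))
```

As in the benchmark, the total number of leapfrog steps is simply the number of samples times the fixed trajectory length in steps.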

## Benchmark NUTS

In [5]:
mcmc = MCMC(NUTS(model), num_warmup=50, num_samples=50, progress_bar=False)
# run warmup first and copy its output to host memory, which waits for warmup
# to finish, so that warmup is excluded from the timing below
mcmc.warmup(random.PRNGKey(2019), features, labels, extra_fields=("num_steps",))
np.asarray(mcmc.get_extra_fields()["num_steps"])
tic = time.time()
mcmc.run(random.PRNGKey(2020), features, labels, extra_fields=("num_steps",))
# int(...) waits for sampling to finish, since JAX dispatches work asynchronously
num_leapfrogs = int(mcmc.get_extra_fields()["num_steps"].sum())
toc = time.time()
print("number of leapfrog steps:", num_leapfrogs)
print("avg. time for each step :", (toc - tic) / num_leapfrogs)
mcmc.print_summary()

number of leapfrog steps: 47406
avg. time for each step : 0.0022662237908313812

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  coefs[0]      1.97      0.01      1.97      1.95      1.98     74.56      1.05
  coefs[1]     -0.04      0.00     -0.04     -0.05     -0.03     59.26      0.99
  coefs[2]     -0.07      0.01     -0.06     -0.08     -0.05     35.80      1.12
  coefs[3]     -0.30      0.00     -0.30     -0.31     -0.29     54.31      1.00
  coefs[4]     -0.09      0.00     -0.09     -0.10     -0.09     38.45      0.99
  coefs[5]     -0.14      0.00     -0.14     -0.15     -0.14     26.25      1.12
  coefs[6]      0.23      0.04      0.24      0.19      0.30     11.98      1.18
  coefs[7]     -0.65      0.02     -0.65     -0.69     -0.62     17.16      1.16
  coefs[8]      0.57      0.04      0.57      0.48      0.62     12.71      1.18
  coefs[9]     -0.01      0.00     -0.01     -0.02     -0.01     58.92      0.99
 coefs[10]      0.71      0.

On CPU, we get `avg. time for each step : 0.028006251705287415`.
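Unlike HMC, NUTS does not fix the number of leapfrog steps: it keeps extending each trajectory until it starts to turn back on itself, so the step count varies from sample to sample (47406 steps for 50 samples above). The sketch below implements the U-turn criterion from the original NUTS paper (Hoffman and Gelman, 2014) in NumPy and uses it in a simplified, forward-only loop with no tree doubling and no sampling. The helper names are hypothetical, and NumPyro's iterative NUTS may use a generalized form of this criterion.

```python
import numpy as np


def is_u_turn(q_minus, q_plus, p_minus, p_plus):
    """U-turn criterion of Hoffman & Gelman (2014): True once the momentum at
    either end of the trajectory points back toward the other end."""
    dq = q_plus - q_minus
    return bool(dq @ p_minus < 0 or dq @ p_plus < 0)


def steps_until_u_turn(q, p, grad_u, step_size, max_steps=1000):
    """Hypothetical helper: extend a trajectory forward one leapfrog step at a
    time until the U-turn criterion fires (no tree doubling, no sampling)."""
    q_start, p_start = q.copy(), p.copy()
    for n in range(1, max_steps + 1):
        p = p - 0.5 * step_size * grad_u(q)
        q = q + step_size * p
        p = p - 0.5 * step_size * grad_u(q)
        if is_u_turn(q_start, q, p_start, p):
            return n
    return max_steps


def gaussian_grad(scale):
    """Gradient of U(q) = q^2 / (2 * scale^2), i.e. a N(0, scale^2) target."""
    return lambda q: q / scale**2


# Start one standard deviation from the mode with zero momentum.
steps = {
    scale: steps_until_u_turn(np.array([scale]), np.array([0.0]), gaussian_grad(scale), step_size=0.1)
    for scale in (1.0, 3.0)
}
print(steps)
```

With the same step size, the wider target needs a longer trajectory before it turns around, which is how NUTS adapts the number of leapfrog steps to the geometry of the posterior.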

## Compare to other frameworks

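| Framework (device) | HMC (ms per leapfrog step) | NUTS (ms per leapfrog step) |
| ------------------ | -------------------------: | --------------------------: |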
| Edward2 (CPU) |           |  56.1 ms  |
| Edward2 (GPU) |           |   9.4 ms  |
| Pyro (CPU)    |  35.4 ms  |  35.3 ms  |
| Pyro (GPU)    |   3.5 ms  |   4.2 ms  |
| NumPyro (CPU) |  27.8 ms  |  28.0 ms  |
| NumPyro (GPU) |   1.6 ms  |   2.2 ms  |

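Note that in some situations (here, Pyro on CPU), HMC is slower per leapfrog step than NUTS. The reason is that each HMC trajectory uses a fixed number of $10$ leapfrog steps, while NUTS chooses the number of steps adaptively (47406 steps for 50 samples in the NumPyro run above), so any fixed per-sample overhead is spread over far fewer leapfrog steps in HMC than in NUTS.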
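The per-sample step counts behind this explanation follow directly from the GPU outputs printed earlier in this notebook. The short calculation below uses only those printed numbers: HMC takes 10 leapfrog steps per sample while NUTS takes about 948, so the two samplers have similar costs per leapfrog step but very different costs per sample.

```python
# Numbers copied from the GPU outputs printed above.
hmc_samples, hmc_steps, hmc_time_per_step = 500, 5000, 0.0015881952285766601
nuts_samples, nuts_steps, nuts_time_per_step = 50, 47406, 0.0022662237908313812

hmc_steps_per_sample = hmc_steps / hmc_samples     # fixed trajectory length
nuts_steps_per_sample = nuts_steps / nuts_samples  # adaptive trajectory length

# Wall time of the timed `mcmc.run` call = number of steps x average time per step.
hmc_total = hmc_steps * hmc_time_per_step
nuts_total = nuts_steps * nuts_time_per_step

print(f"HMC : {hmc_steps_per_sample:.1f} steps/sample, {hmc_total / hmc_samples:.4f} s/sample")
print(f"NUTS: {nuts_steps_per_sample:.1f} steps/sample, {nuts_total / nuts_samples:.4f} s/sample")
```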

**Some takeaways:**
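+ The overhead of NumPyro's iterative NUTS is small: its time per leapfrog step is close to that of HMC (28.0 ms vs. 27.8 ms on CPU). So most of the computation time is indeed spent evaluating the potential function and its gradient.
+ GPU outperforms CPU by a large margin (for NumPyro, about 13 to 17 times faster per leapfrog step). Because the dataset is large, evaluating the potential function on GPU is much faster than on CPU.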
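The size of the GPU advantage can be computed for every framework and sampler in the table. The snippet below divides the CPU time per leapfrog step by the GPU time, using only the table's numbers and `None` for the Edward2 HMC entries that the table leaves empty.

```python
# Time per leapfrog step in ms, copied from the table above (None = not reported).
timings = {
    "Edward2": {"CPU": {"HMC": None, "NUTS": 56.1}, "GPU": {"HMC": None, "NUTS": 9.4}},
    "Pyro": {"CPU": {"HMC": 35.4, "NUTS": 35.3}, "GPU": {"HMC": 3.5, "NUTS": 4.2}},
    "NumPyro": {"CPU": {"HMC": 27.8, "NUTS": 28.0}, "GPU": {"HMC": 1.6, "NUTS": 2.2}},
}

speedups = {}
for framework, devices in timings.items():
    for sampler in ("HMC", "NUTS"):
        cpu, gpu = devices["CPU"][sampler], devices["GPU"][sampler]
        if cpu is not None and gpu is not None:
            speedups[(framework, sampler)] = cpu / gpu  # how many times faster the GPU is

for (framework, sampler), ratio in speedups.items():
    print(f"{framework:8s} {sampler:4s} GPU speedup: {ratio:.1f}x")
```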

## References

1. Dustin Tran, Matthew D. Hoffman, Dave Moore, Christopher Suter, Srinivas Vasudevan, Alexey Radul, Matthew Johnson, Rif A. Saurous. *Simple, Distributed, and Accelerated Probabilistic Programming*. [arXiv:1811.02091](https://arxiv.org/abs/1811.02091)