# JAX

Let's look at a problem that resembles machine learning a little: curve fitting for unbinned data. We are going to ignore the actual minimizer and instead just compute the negative log-likelihood (NLL).

```python
# Uncomment to enable 64-bit floats in JAX (restart the kernel after changing this):
# import jax
# jax.config.update("jax_enable_x64", True)
```

```python
import numpy as np

np.random.seed(42)

# Two Gaussians with the same mean but different widths, 1M samples each
dist = np.hstack(
    [
        np.random.normal(loc=1, scale=2.0, size=1_000_000),
        np.random.normal(loc=1, scale=0.5, size=1_000_000),
    ]
)
```

Let's start with NumPy, just to show how it would be done:

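```python
def gaussian(x, μ, σ):
    # Normalized Gaussian probability density
    return 1 / np.sqrt(2 * np.pi * σ**2) * np.exp(-((x - μ) ** 2) / (2 * σ**2))


def add(x, f_0, μ, σ, σ2):
    # Mixture of two Gaussians sharing a mean; f_0 is the fraction in the first one
    return f_0 * gaussian(x, μ, σ) + (1 - f_0) * gaussian(x, μ, σ2)


def nll(x, f_0, μ, σ, σ2):
    # Negative log-likelihood of the data under the mixture model
    return -np.sum(np.log(add(x, f_0, μ, σ, σ2)))
```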
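Written out, the three functions above evaluate the negative log-likelihood of a two-component Gaussian mixture with a shared mean $\mu$, widths $\sigma$ and $\sigma_2$, and fraction $f_0$ in the first component:

$$
\mathrm{NLL}(f_0, \mu, \sigma, \sigma_2) = -\sum_{i} \log\Big[ f_0\,G(x_i; \mu, \sigma) + (1 - f_0)\,G(x_i; \mu, \sigma_2) \Big],
\qquad
G(x; \mu, \sigma) = \frac{1}{\sqrt{2\pi\sigma^2}}\,\exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)
$$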

```python
%%time
nll(dist, *np.random.rand(4))
```

```python
%%timeit
nll(dist, *np.random.rand(4))
```

## JAX

JAX is a library from Google. It can target a wide variety of backends (CPU, GPU, TPU), JIT-compile code through XLA, and compute gradients automatically. It is _very_ powerful, but also rather tricky, since it does quite a few things differently from NumPy. First, let's try using it:

```python
import jax
import jax.numpy as jnp
```
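As a quick illustration of the gradient feature (not part of the original notebook), here is a sketch that differentiates a toy NLL for a unit-width Gaussian with `jax.grad` and checks it against the analytic derivative. The helper `gaussian_nll` is hypothetical and exists only for this example:

```python
import jax
import jax.numpy as jnp


def gaussian_nll(mu, x):
    # NLL of a unit-width Gaussian, dropping the constant normalization term
    return jnp.sum((x - mu) ** 2) / 2


x = jnp.array([0.0, 1.0, 2.0])

# jax.grad differentiates with respect to the first argument (mu) by default
dnll_dmu = jax.grad(gaussian_nll)(0.0, x)
print(dnll_dmu)  # analytic result: -sum(x - mu) = -3

# Devices JAX found on this machine (at least one CPU device)
print(jax.devices())
```

The same `jax.grad` call would work on the full mixture `nll` below, which is what a minimizer would use.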

Now we just replace `np` with `jnp` everywhere in the code above, which gives:

```python
def gaussian(x, μ, σ):
    return 1 / jnp.sqrt(2 * jnp.pi * σ**2) * jnp.exp(-((x - μ) ** 2) / (2 * σ**2))


def add(x, f_0, μ, σ, σ2):
    return f_0 * gaussian(x, μ, σ) + (1 - f_0) * gaussian(x, μ, σ2)


def nll(x, f_0, μ, σ, σ2):
    return -jnp.sum(jnp.log(add(x, f_0, μ, σ, σ2)))
```

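Next, we convert the data to a JAX array. The `jnp` functions also accept NumPy arrays, but they would convert (and possibly copy) the data on every call; converting once up front avoids that: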

```python
d_dist = jnp.asarray(dist)
```
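A small sketch of a few ways JAX arrays differ from NumPy arrays: the dtype after conversion, immutability, and the `.at[...].set(...)` functional update. The variable names `host`, `dev`, `updated`, and `back` are our own illustration:

```python
import numpy as np
import jax.numpy as jnp

host = np.linspace(0.0, 1.0, 5)  # NumPy array, float64
dev = jnp.asarray(host)  # JAX array: float32 unless x64 mode is enabled
print(type(dev).__name__, dev.dtype)

# JAX arrays are immutable, so in-place assignment raises a TypeError
try:
    dev[0] = 5.0
except TypeError as err:
    print("TypeError:", err)

# Functional update: returns a new array and leaves `dev` unchanged
updated = dev.at[0].set(5.0)

# np.asarray converts a JAX array back to a NumPy array on the host
back = np.asarray(updated)
```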

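There's one more step still to come (JIT compilation), but let's time this version first. JAX dispatches work asynchronously, so we call `.block_until_ready()` to time the computation itself rather than just its dispatch: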

```python
%%time
nll(d_dist, *np.random.rand(4)).block_until_ready()
```

```python
%%timeit
nll(d_dist, *np.random.rand(4)).block_until_ready()
```

You are probably already seeing a nice speedup here, even without JIT compilation. File that away for now; we'll explain it later.
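The timing cells above call `.block_until_ready()` because a JAX operation can return before its result has been computed. A minimal sketch that times dispatch separately from completion; on backends with asynchronous dispatch the first number can be much smaller than the second, but exact timings depend on your hardware and JAX version:

```python
import time

import jax.numpy as jnp

x = jnp.ones((1000, 1000))

start = time.perf_counter()
y = x @ x  # may return once the work is queued (asynchronous dispatch)
dispatched = time.perf_counter() - start

y.block_until_ready()  # wait until the result has actually been computed
finished = time.perf_counter() - start

print(f"dispatch: {dispatched:.4f} s, finished: {finished:.4f} s")
```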

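Now we can JIT-compile our function. Unlike Numba, where every function called from a compiled function must be compiled too, we only pass in the top-level function; JAX traces through the calls to `add` and `gaussian`.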

```python
nll_jit = jax.jit(nll)
```
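The same idea in miniature, using hypothetical functions `inner` and `outer`: only `outer` is wrapped with `jax.jit` (here in decorator form), and the undecorated helper it calls is compiled as part of it:

```python
import jax
import jax.numpy as jnp


def inner(x):
    # A plain Python function: it is not jitted itself
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


@jax.jit  # decorator form, equivalent to outer = jax.jit(outer)
def outer(x):
    # JAX traces through the call to `inner` while compiling `outer`
    return jnp.sum(inner(x))


result = outer(jnp.linspace(0.0, 3.0, 4))
print(result)  # sin² + cos² = 1 at each of the 4 points, so about 4.0
```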

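The first time we call it, JAX will "trace" the function, recording the operations applied to abstract placeholder values, and compile the result with XLA; this first call therefore includes the compilation time. Like other tracers, it can't handle Python control flow that depends on array values (such as `if x > 0:`); use vectorized alternatives like `jnp.where` instead.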
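A sketch of the control-flow limitation, using two hypothetical ReLU implementations: a Python `if` on a traced value fails inside `jax.jit`, while `jnp.where` traces fine. `jax.make_jaxpr` prints the intermediate program that tracing produces:

```python
import jax
import jax.numpy as jnp


def relu_python(x):
    # Python `if` needs a concrete bool, which a tracer cannot provide
    if x > 0:
        return x
    return jnp.zeros_like(x)


def relu_vectorized(x):
    # Both branches are computed; jnp.where picks element-wise
    return jnp.where(x > 0, x, 0.0)


try:
    jax.jit(relu_python)(jnp.array(1.0))
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
print("Python `if` under jit works:", traced_ok)

out = jax.jit(relu_vectorized)(jnp.array([-1.0, 2.0]))
print(out)  # [0. 2.]

# make_jaxpr shows the traced program that JAX hands to XLA
print(jax.make_jaxpr(relu_vectorized)(jnp.array([-1.0, 2.0])))
```

For genuine branching on a value, `jax.lax.cond` is the traceable alternative.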

```python
%%time
nll_jit(d_dist, *np.random.rand(4)).block_until_ready()
```
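The compiled function is cached by the shapes and dtypes of its inputs, so calling it with new values (as the random parameters above do) reuses the compiled code, while a new input shape triggers a fresh trace. A minimal sketch that counts traces with a Python side effect, which only runs at trace time; `sum_of_squares` and `trace_count` are our own illustration:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def sum_of_squares(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return jnp.sum(x**2)


sum_of_squares(jnp.ones(3))  # traces and compiles for shape (3,)
sum_of_squares(jnp.ones(3))  # same shape and dtype: reuses the cached version
sum_of_squares(jnp.ones(4))  # new shape: traces and compiles again
print(trace_count)  # 2
```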

Now that it's compiled, let's measure:

```python
%%timeit
nll_jit(d_dist, *np.random.rand(4)).block_until_ready()
```

This is very nice, but there is a caveat: JAX defaults to 32-bit floats, while the NumPy version above computes in 64-bit. Uncomment the code at the top, _restart_ the kernel, and compare the timings again.
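A minimal sketch of what the switch changes, assuming it runs in a fresh process: with `jax_enable_x64` on, arrays default to `float64`, and the two sums below show how much less resolution `float32` has (about 7 significant digits, versus about 16 for `float64`). When summing two million log terms, that difference can show up in the NLL value:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode; do this at startup, before creating any JAX arrays
jax.config.update("jax_enable_x64", True)

x = jnp.asarray([0.1, 0.2, 0.3])
print(x.dtype)  # float64 with x64 mode enabled

# Adding 1e-8 to 1.0 is lost in float32 but kept in float64
small32 = jnp.float32(1.0) + jnp.float32(1e-8)
small64 = jnp.float64(1.0) + jnp.float64(1e-8)
print(bool(small32 == 1.0), bool(small64 == 1.0))  # True False
```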

#### Further reading:

* [CompClass: Fitting](https://github.com/henryiii/compclass/blob/master/classes/week12/1_fitting.ipynb)