{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JAX\n", "\n", "Let's look at a problem that looks just a bit like machine learning: Curve fitting for unbinned data. We are going to ignore the actual minimizer, and instead just compute the negative log likelihood (nll)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "source_hidden": true }, "tags": [] }, "outputs": [], "source": [ "# from jax.config import config\n", "# config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "np.random.seed(42)\n", "\n", "dist = np.hstack(\n", " [\n", " np.random.normal(loc=1, scale=2.0, size=1_000_000),\n", " np.random.normal(loc=1, scale=0.5, size=1_000_000),\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start with NumPy, just to show how it would be done:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def gaussian(x, μ, σ):\n", " return 1 / np.sqrt(2 * np.pi * σ**2) * np.exp(-((x - μ) ** 2) / (2 * σ**2))\n", "\n", "\n", "def add(x, f_0, μ, σ, σ2):\n", " return f_0 * gaussian(x, μ, σ) + (1 - f_0) * gaussian(x, μ, σ2)\n", "\n", "\n", "def nll(x, f_0, μ, σ, σ2):\n", " return -np.sum(np.log(add(x, f_0, μ, σ, σ2)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "nll(dist, *np.random.rand(4))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%timeit\n", "nll(dist, *np.random.rand(4))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Jax\n", "\n", "Jax is a tool from Google. It can target a wide variety of backends (CPU, GPU, TPU), can JIT compile, and can take gradients. It is _very_ powerful, and rather tricky, since it does quite a few things a bit differently. First let's try using it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we'll just replace `np` with `jnp` everywhere in the above code, to produce:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def gaussian(x, μ, σ):\n", " return 1 / jnp.sqrt(2 * jnp.pi * σ**2) * jnp.exp(-((x - μ) ** 2) / (2 * σ**2))\n", "\n", "\n", "def add(x, f_0, μ, σ, σ2):\n", " return f_0 * gaussian(x, μ, σ) + (1 - f_0) * gaussian(x, μ, σ2)\n", "\n", "\n", "def nll(x, f_0, μ, σ, σ2):\n", " return -jnp.sum(jnp.log(add(x, f_0, μ, σ, σ2)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need just one more step - we need Jax arrays instead of NumPy arrays:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "d_dist = jnp.asarray(dist)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There's one more step, but let's just check this first:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "nll(d_dist, *np.random.rand(4)).block_until_ready()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%timeit\n", "nll(d_dist, *np.random.rand(4)).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We probably are seeing a nice speedup here. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we can JIT our function. Unlike Numba, we just pass in the top-level function:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "nll_jit = jax.jit(nll)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The first time we call it, JAX will \"trace\" the function and produce XLA code for it. Like other tracers, it can't handle Python control flow that depends on traced values; such branches have to be written in vectorized form (for example with `jnp.where`) or with JAX's structured control-flow primitives." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "nll_jit(d_dist, *np.random.rand(4)).block_until_ready()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now that it's primed, let's measure:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%timeit\n", "nll_jit(d_dist, *np.random.rand(4)).block_until_ready()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This is very nice, but there is a caveat: JAX runs in 32-bit (single-precision) mode by default, which is also where much of the pre-JIT speedup we filed away came from. Uncomment the cell at the top, _restart_ the kernel, and compare the timings again in 64-bit mode." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Further reading:\n", "\n", "* [CompClass: Fitting](https://github.com/henryiii/compclass/blob/master/classes/week12/1_fitting.ipynb)" ] }
], "metadata": { "kernelspec": { "display_name": "Python [conda env:performance-minicourse] *", "language": "python", "name": "conda-env-performance-minicourse-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 4 }