{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# You don't know JAX\n", "\n", "This brief tutorial covers the basics of [JAX](https://github.com/google/jax/). JAX is a Python library which augments `numpy` and Python code with *function transformations* which make it trivial to perform operations common in machine learning programs. Concretely, this makes it simple to write standard Python/`numpy` code and immediately be able to\n", "\n", "- Compute the derivative of a function via a successor to [autograd](https://github.com/HIPS/autograd/)\n", "- Just-in-time compile a function to run efficiently on an accelerator via [XLA](https://www.tensorflow.org/xla/)\n", "- Automagically vectorize a function, so that e.g. you can process a \"batch\" of data in parallel\n", "\n", "In this tutorial, we'll cover each of these transformations in turn by demonstrating their use on one of the core problems of AGI: learning the Exclusive OR (XOR) function with a neural network.\n", "\n", "## JAX is just `numpy` (mostly)\n", "\n", "At its core, you can think of JAX as augmenting `numpy` with the machinery required to perform the aforementioned transformations. JAX's augmented numpy lives at `jax.numpy`. With a few exceptions, you can think of `jax.numpy` as directly interchangeable with `numpy`. As a general rule, you should use `jax.numpy` whenever you plan to use any of JAX's transformations (like computing gradients or just-in-time compiling code) and whenever you want the code to run on an accelerator. You only ever *need* to use `numpy` when you're computing something which is not supported by `jax.numpy`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import random\n", "import itertools\n", "\n", "import jax\n", "import jax.numpy as np\n", "# Current convention is to import original numpy as \"onp\"\n", "import numpy as onp\n", "\n", "from __future__ import print_function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Background\n", "\n", "As previously mentioned, we will be learning the XOR function with a small neural network. The XOR function takes as input two binary numbers and outputs a binary number, like so:\n", "\n", "| In 1 | In 2 | Out |\n", "|-------------------|\n", "| 0 | 0 | 0 |\n", "| 0 | 1 | 1 |\n", "| 1 | 0 | 1 |\n", "| 1 | 1 | 0 |\n", "\n", "We'll use a neural network with a single hidden layer with 3 neurons and a hyperbolic tangent nonlinearity, trained with the cross-entropy loss via stochastic gradient descent. Let's implement this model and loss function. Note that the code is exactly as you'd write in standard `numpy`." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Sigmoid nonlinearity\n", "def sigmoid(x):\n", " return 1 / (1 + np.exp(-x))\n", "\n", "# Computes our network's output\n", "def net(params, x):\n", " w1, b1, w2, b2 = params\n", " hidden = np.tanh(np.dot(w1, x) + b1)\n", " return sigmoid(np.dot(w2, hidden) + b2)\n", "\n", "# Cross-entropy loss\n", "def loss(params, x, y):\n", " out = net(params, x)\n", " cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)\n", " return cross_entropy\n", "\n", "# Utility function for testing whether the net produces the correct\n", "# output for all possible inputs\n", "def test_all_inputs(inputs, params):\n", " predictions = [int(net(params, inp) > 0.5) for inp in inputs]\n", " for inp, out in zip(inputs, predictions):\n", " print(inp, '->', out)\n", " return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As mentioned above, there are some places where we want to use standard `numpy` rather than `jax.numpy`. One of those places is with parameter initialization. We'd like to initialize our parameters randomly before we train our network, which is not an operation for which we need derivatives or compilation. JAX uses its own `jax.random` library instead of `numpy.random` which provides better support for reproducibility (seeding) across different transformations. Since we don't need to transform the initialization of parameters in any way, it's simplest just to use standard `numpy.random` instead of `jax.random` here." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def initial_params():\n", " return [\n", " onp.random.randn(3, 2), # w1\n", " onp.random.randn(3), # b1\n", " onp.random.randn(3), # w2\n", " onp.random.randn(), #b2\n", " ]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `jax.grad`\n", "\n", "The first transformation we'll use is `jax.grad`. `jax.grad` takes a function and returns a new function which computes the gradient of the original function. By default, the gradient is taken with respect to the first argument; this can be controlled via the `argnums` argument to `jax.grad`. To use gradient descent, we want to be able to compute the gradient of our loss function with respect to our neural network's parameters. For this, we'll simply use `jax.grad(loss)` which will give us a function we can call to get these gradients." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/craffel/Documents/libraries/jax/jax/lib/xla_bridge.py:146: UserWarning: No GPU found, falling back to CPU.\n", " warnings.warn('No GPU found, falling back to CPU.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Iteration 0\n", "[0 0] -> 1\n", "[0 1] -> 0\n", "[1 0] -> 1\n", "[1 1] -> 1\n", "Iteration 100\n", "[0 0] -> 0\n", "[0 1] -> 1\n", "[1 0] -> 1\n", "[1 1] -> 0\n" ] } ], "source": [ "loss_grad = jax.grad(loss)\n", "\n", "# Stochastic gradient descent learning rate\n", "learning_rate = 1.\n", "# All possible inputs\n", "inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n", "\n", "# Initialize parameters randomly\n", "params = initial_params()\n", "\n", "for n in itertools.count():\n", " # Grab a single random input\n", " x = inputs[onp.random.choice(inputs.shape[0])]\n", " # Compute the target output\n", " y = onp.bitwise_xor(*x)\n", " # Get the gradient of the loss for this input/output pair\n", " grads = loss_grad(params, x, y)\n", " # Update parameters via gradient descent\n", " params = [param - learning_rate * grad\n", " for param, grad in zip(params, grads)]\n", " # Every 100 iterations, check whether we've solved XOR\n", " if not n % 100:\n", " print('Iteration {}'.format(n))\n", " if test_all_inputs(inputs, params):\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `jax.jit`\n", "\n", "While carefully-written `numpy` code can be reasonably performant, for modern machine learning we want our code to run as fast as possible. This often involves running our code on different \"accelerators\" like GPUs or TPUs. JAX provides a JIT (just-in-time) compiler which takes a standard Python/`numpy` function and compiles it to run efficiently on an accelerator. Compiling a function also avoids the overhead of the Python interpreter, which helps whether or not you're using an accelerator. In total, `jax.jit` can dramatically speed-up your code with essentially no coding overhead - you just ask JAX to compile the function for you. Even our tiny neural network can see a pretty dramatic speedup when using `jax.jit`:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 loops, best of 3: 13.1 ms per loop\n", "1000 loops, best of 3: 862 µs per loop\n" ] } ], "source": [ "# Time the original gradient function\n", "%timeit loss_grad(params, x, y)\n", "loss_grad = jax.jit(jax.grad(loss))\n", "# Run once to trigger JIT compilation\n", "loss_grad(params, x, y)\n", "%timeit loss_grad(params, x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that JAX allows us to aribtrarily chain together transformations - first, we took the gradient of `loss` using `jax.grad`, then we just-in-time compiled it using `jax.jit`. This is one of the things that makes JAX extra powerful -- apart from chaining `jax.jit` and `jax.grad`, we could also e.g. apply `jax.grad` multiple times to get higher-order derivatives. To make sure that training the neural network still works after compilation, let's train it again. Note that the code for training has not changed whatsoever." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 0\n", "[0 0] -> 1\n", "[0 1] -> 1\n", "[1 0] -> 1\n", "[1 1] -> 1\n", "Iteration 100\n", "[0 0] -> 0\n", "[0 1] -> 1\n", "[1 0] -> 1\n", "[1 1] -> 0\n" ] } ], "source": [ "params = initial_params()\n", "\n", "for n in itertools.count():\n", " x = inputs[onp.random.choice(inputs.shape[0])]\n", " y = onp.bitwise_xor(*x)\n", " grads = loss_grad(params, x, y)\n", " params = [param - learning_rate * grad\n", " for param, grad in zip(params, grads)]\n", " if not n % 100:\n", " print('Iteration {}'.format(n))\n", " if test_all_inputs(inputs, params):\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `jax.vmap`\n", "\n", "An astute reader may have noticed that we have been training our neural network on a single example at a time. This is \"true\" stochastic gradient descent; in practice, when training modern machine learning models we perform \"minibatch\" gradient descent where we average the loss gradients over a mini-batch of examples at each step of gradient descent. JAX provides `jax.vmap`, which is a transformation which automatically \"vectorizes\" a function. What this means is that it allows you to compute the output of a function in parallel over some axis of the input. For us, this means we can apply the `jax.vmap` function transformation and immediately get a version of our loss function gradient which is amenable to using a minibatch of examples. \n", "\n", "`jax.vmap` takes in additional arguments:\n", "- `in_axes` is a tuple or integer which tells JAX over which axes the function's arguments should be parallelized. The tuple should have the same length as the number of arguments of the function being `vmap`'d, or should be an integer when there is only one argument. In our example, we'll use `(None, 0, 0)`, meaning \"don't parallelize over the first argument (`params`), and parallelize over the first (zeroth) dimension of the second and third arguments (`x` and `y`)\".\n", "- `out_axes` is analogous to `in_axes`, except it specifies which axes of the function's output to parallelize over. In our case, we'll use `0`, meaning to parallelize over the first (zeroth) dimension of the function's sole output (the loss gradients).\n", "\n", "Note that we will have to change the training code a little bit - we need to grab a batch of data instead of a single example at a time, and we need to average the gradients over the batch before applying them to update the parameters." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration 0\n", "[0 0] -> 0\n", "[0 1] -> 0\n", "[1 0] -> 0\n", "[1 1] -> 0\n", "Iteration 100\n", "[0 0] -> 0\n", "[0 1] -> 1\n", "[1 0] -> 1\n", "[1 1] -> 0\n" ] } ], "source": [ "loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))\n", "\n", "params = initial_params()\n", "\n", "batch_size = 100\n", "\n", "for n in itertools.count():\n", " # Generate a batch of inputs\n", " x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]\n", " y = onp.bitwise_xor(x[:, 0], x[:, 1])\n", " # The call to loss_grad remains the same!\n", " grads = loss_grad(params, x, y)\n", " # Note that we now need to average gradients over the batch\n", " params = [param - learning_rate * np.mean(grad, axis=0)\n", " for param, grad in zip(params, grads)]\n", " if not n % 100:\n", " print('Iteration {}'.format(n))\n", " if test_all_inputs(inputs, params):\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pointers\n", "\n", "That's all we'll be covering in this short tutorial, but this actually covers a great deal of JAX. Since JAX is mostly `numpy` and Python, you can leverage your existing knowledge instead of having to learn a fundamentally new framework or paradigm. For additional resources, check the [notebooks](https://github.com/google/jax/tree/master/notebooks) and [examples](https://github.com/google/jax/tree/master/examples) directories on [JAX's GitHub](https://github.com/google/jax)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 2 }