{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Autodiff Cookbook.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" } }, "cells": [ { "metadata": { "colab_type": "code", "id": "pQYNFLfn6zlo", "outputId": "3791adad-0231-449a-9034-4f1bae6f59fa", "colab": { "base_uri": "https://localhost:8080/", "height": 86 } }, "cell_type": "code", "source": [ "!pip install -q --upgrade jax jaxlib" ], "execution_count": 4, "outputs": [ { "output_type": "stream", "text": [ "\u001b[K 100% |████████████████████████████████| 215kB 6.8MB/s \n", "\u001b[K 100% |████████████████████████████████| 21.1MB 1.1MB/s \n", "\u001b[K 100% |████████████████████████████████| 61kB 18.0MB/s \n", "\u001b[?25h Building wheel for opt-einsum (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "code", "id": "JTYyZkSO6vuy", "outputId": "375ffa68-d439-4f1e-ad56-52cec9d2ed3f", "colab": { "base_uri": "https://localhost:8080/", "height": 52 } }, "cell_type": "code", "source": [ "from __future__ import print_function, division\n", "import jax.numpy as np\n", "from jax import grad, jit, vmap\n", "from jax import random\n", "\n", "key = random.PRNGKey(0)" ], "execution_count": 5, "outputs": [ { "output_type": "stream", "text": [ "/usr/local/lib/python2.7/dist-packages/jax/lib/xla_bridge.py:167: UserWarning: No GPU found, falling back to CPU.\n", " warnings.warn('No GPU found, falling back to CPU.')\n" ], "name": "stderr" } ] }, { "metadata": { "colab_type": "text", "id": "Ic1reB4s6vu1" }, "cell_type": "markdown", "source": [ "# The Autodiff Cookbook\n", "\n", "*alexbw@, mattjj@* \n", "\n", "JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics." 
] }, { "metadata": { "colab_type": "text", "id": "YxnjtAGN6vu2" }, "cell_type": "markdown", "source": [ "## Gradients" ] }, { "metadata": { "colab_type": "text", "id": "zqwpfr2vAsvt" }, "cell_type": "markdown", "source": [ "### Starting with `grad`\n", "\n", "You can differentiate a function with `grad`:" ] }, { "metadata": { "colab_type": "code", "id": "0NLO4Wfknzmk", "outputId": "bb76443a-0fc4-4bac-e95c-b1377cc58217", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "cell_type": "code", "source": [ "grad_tanh = grad(np.tanh)\n", "print(grad_tanh(2.0))" ], "execution_count": 6, "outputs": [ { "output_type": "stream", "text": [ "0.070650816\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "LGcNfDntoBZI" }, "cell_type": "markdown", "source": [ "`grad` takes a function and returns a function. If you have a Python function `f` that evaluates the mathematical function $f$, then `grad(f)` is a Python function that evaluates the mathematical function $\\nabla f$. That means `grad(f)(x)` represents the value $\\nabla f(x)$.\n", "\n", "Since `grad` operates on functions, you can apply it to its own output to differentiate as many times as you like:" ] }, { "metadata": { "colab_type": "code", "id": "RDGk1GDsoawu", "outputId": "6a3ea5ea-c29d-4774-9961-9ea5ac5a7fc7", "colab": { "base_uri": "https://localhost:8080/", "height": 52 } }, "cell_type": "code", "source": [ "print(grad(grad(np.tanh))(2.0))\n", "print(grad(grad(grad(np.tanh)))(2.0))" ], "execution_count": 7, "outputs": [ { "output_type": "stream", "text": [ "-0.13621867\n", "0.25265405\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "2rcnpTiinqi8" }, "cell_type": "markdown", "source": [ "Let's look at computing gradients with `grad` in a linear logistic regression model. 
First, the setup:" ] }, { "metadata": { "colab_type": "code", "id": "27TcOT2i6vu5", "colab": {} }, "cell_type": "code", "source": [ "def sigmoid(x):\n", " return 0.5 * (np.tanh(x / 2) + 1)\n", "\n", "# Outputs probability of a label being true.\n", "def predict(W, b, inputs):\n", " return sigmoid(np.dot(inputs, W) + b)\n", "\n", "# Build a toy dataset.\n", "inputs = np.array([[0.52, 1.12, 0.77],\n", " [0.88, -1.08, 0.15],\n", " [0.52, 0.06, -1.30],\n", " [0.74, -2.49, 1.39]])\n", "targets = np.array([True, True, False, True])\n", "\n", "# Training loss is the negative log-likelihood of the training examples.\n", "def loss(W, b):\n", " preds = predict(W, b, inputs)\n", " label_probs = preds * targets + (1 - preds) * (1 - targets)\n", " return -np.sum(np.log(label_probs))\n", "\n", "# Initialize random model coefficients\n", "key, W_key, b_key = random.split(key, 3)\n", "W = random.normal(W_key, (3,))\n", "b = random.normal(b_key, ())" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "8Wk-Yai7ooh1" }, "cell_type": "markdown", "source": [ "Use the `grad` function with its `argnums` argument to differentiate a function with respect to positional arguments." 
] }, { "metadata": { "colab_type": "code", "id": "bpmd8W8-6vu6", "outputId": "73a9b189-5c53-45ff-dff4-6068679ea598", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "cell_type": "code", "source": [ "# Differentiate `loss` with respect to the first positional argument:\n", "W_grad = grad(loss, argnums=0)(W, b)\n", "print('W_grad', W_grad)\n", "\n", "# Since argnums=0 is the default, this does the same thing:\n", "W_grad = grad(loss)(W, b)\n", "print('W_grad', W_grad)\n", "\n", "# But we can choose different values too, and drop the keyword:\n", "b_grad = grad(loss, 1)(W, b)\n", "print('b_grad', b_grad)\n", "\n", "# Including tuple values\n", "W_grad, b_grad = grad(loss, (0, 1))(W, b)\n", "print('W_grad', W_grad)\n", "print('b_grad', b_grad)" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "text": [ "W_grad [-0.16965586 -0.8774649 -1.4901347 ]\n", "W_grad [-0.16965586 -0.8774649 -1.4901347 ]\n", "b_grad -0.2922725\n", "W_grad [-0.16965586 -0.8774649 -1.4901347 ]\n", "b_grad -0.2922725\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "MDl5UZl4oyzB" }, "cell_type": "markdown", "source": [ "This `grad` API has a direct correspondence to the excellent notation in Spivak's classic *Calculus on Manifolds* (1965), also used in Sussman and Wisdom's [*Structure and Interpretation of Classical Mechanics*](http://mitpress.mit.edu/sites/default/files/titles/content/sicm_edition_2/book.html) (2015) and their [*Functional Differential Geometry*](https://mitpress.mit.edu/books/functional-differential-geometry) (2013). Both books are open-access. See in particular the \"Prologue\" section of *Functional Differential Geometry* for a defense of this notation.\n", "\n", "Essentially, when using the `argnums` argument, if `f` is a Python function for evaluating the mathematical function $f$, then the Python expression `grad(f, i)` evaluates to a Python function for evaluating $\\partial_i f$." 
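, "\n", "\n", "As a small illustration of reading `grad(f, i)` as $\\partial_i f$, take a hypothetical two-argument function $g(x, y) = x y^2$ (not part of the model above). Its partial derivatives are $\\partial_0 g(x, y) = y^2$ and $\\partial_1 g(x, y) = 2 x y$, so at the point $(3, 2)$ we expect $4$ and $12$:" ] },
{ "metadata": { "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from jax import grad\n", "\n", "# A hypothetical two-argument function, used only for illustration: g(x, y) = x * y**2\n", "def g(x, y):\n", "  return x * y ** 2\n", "\n", "x0, y0 = 3.0, 2.0\n", "dg_dx = grad(g, 0)(x0, y0)  # partial_0 g = y**2, i.e. 4.0 here\n", "dg_dy = grad(g, 1)(x0, y0)  # partial_1 g = 2 * x * y, i.e. 12.0 here\n", "print('dg/dx', dg_dx)\n", "print('dg/dy', dg_dy)" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "colab_type": "text" }, "cell_type": "markdown", "source": [ "Passing a tuple instead, as in `grad(g, (0, 1))(x0, y0)`, returns both partial derivatives at once, mirroring the `loss` example earlier in this notebook."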
] }, { "metadata": { "colab_type": "text", "id": "fuz9E2vzro5E" }, "cell_type": "markdown", "source": [ "### Differentiating with respect to nested lists, tuples, and dicts" ] }, { "metadata": { "colab_type": "text", "id": "QQaPja7puMKi" }, "cell_type": "markdown", "source": [ "Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like." ] }, { "metadata": { "colab_type": "code", "id": "IY82kdAe6vu_", "outputId": "0a176cfb-0ae4-46c2-9eac-041cebc1eb3e", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "cell_type": "code", "source": [ "def loss2(params_dict):\n", " preds = predict(params_dict['W'], params_dict['b'], inputs)\n", " label_probs = preds * targets + (1 - preds) * (1 - targets)\n", " return -np.sum(np.log(label_probs))\n", "\n", "print(grad(loss2)({'W': W, 'b': b}))" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ "{'b': array(-0.2922725, dtype=float32), 'W': array([-0.16965586, -0.8774649 , -1.4901347 ], dtype=float32)}\n" ], "name": "stdout" } ] }, { "metadata": { "id": "cJ2NxiN58bfI", "colab_type": "text" }, "cell_type": "markdown", "source": [ "You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." 
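, "\n", "\n", "For example, here is a minimal sketch that registers a hypothetical `Point` class with `jax.tree_util.register_pytree_node`, supplying one function that flattens a `Point` into its children plus auxiliary data and one that rebuilds it. This assumes your JAX version exposes `register_pytree_node` with that (type, flatten, unflatten) signature:" ] },
{ "metadata": { "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from jax import grad\n", "from jax.tree_util import register_pytree_node\n", "\n", "# A hypothetical container type, used only for illustration.\n", "class Point(object):\n", "  def __init__(self, x, y):\n", "    self.x = x\n", "    self.y = y\n", "\n", "# Tell JAX how to flatten a Point into (children, aux_data) and how to rebuild it.\n", "register_pytree_node(\n", "    Point,\n", "    lambda p: ((p.x, p.y), None),\n", "    lambda aux_data, children: Point(*children))\n", "\n", "def norm_sq(p):\n", "  return p.x ** 2 + p.y ** 2\n", "\n", "# The gradient comes back as a Point holding (2 * x, 2 * y).\n", "p_grad = grad(norm_sq)(Point(1.0, 2.0))\n", "print('d/dx', p_grad.x, 'd/dy', p_grad.y)" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "colab_type": "text" }, "cell_type": "markdown", "source": [ "Because the gradient has the same structure as the input, the result is itself a `Point`, and registered instances can likewise be passed into functions transformed by `jit` or `vmap`."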
] }, { "metadata": { "colab_type": "text", "id": "PaCHzAtGruBz" }, "cell_type": "markdown", "source": [ "### Evaluate a function and its gradient using `value_and_grad`" ] }, { "metadata": { "colab_type": "text", "id": "CSgCjjo-ssnA" }, "cell_type": "markdown", "source": [ "Another convenient function is `value_and_grad`, which efficiently computes both a function's value and its gradient's value in one call:" ] }, { "metadata": { "colab_type": "code", "id": "RsQSyT5p7OJW", "outputId": "32f9398e-3c19-41da-bdc9-48c6388a87eb", "colab": { "base_uri": "https://localhost:8080/", "height": 52 } }, "cell_type": "code", "source": [ "from jax import value_and_grad\n", "loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)\n", "print('loss value', loss_value)\n", "print('loss value', loss(W, b))" ], "execution_count": 11, "outputs": [ { "output_type": "stream", "text": [ "loss value 3.0519395\n", "loss value 3.0519395\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "rYTrH5tKllC_" }, "cell_type": "markdown", "source": [ "### Checking against numerical differences\n", "\n", "A great thing about derivatives is that they're straightforward to check with finite differences:" ] }, { "metadata": { "colab_type": "code", "id": "R8q5RiY3l7Fw", "outputId": "fbc64404-55ce-4373-b53b-eea3b4c33396", "colab": { "base_uri": "https://localhost:8080/", "height": 86 } }, "cell_type": "code", "source": [ "# Set a step size for finite differences calculations\n", "eps = 1e-4\n", "\n", "# Check b_grad with scalar finite differences\n", "b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps\n", "print('b_grad_numerical', b_grad_numerical)\n", "print('b_grad_autodiff', grad(loss, 1)(W, b))\n", "\n", "# Check W_grad with finite differences in a random direction\n", "key, subkey = random.split(key)\n", "vec = random.normal(subkey, W.shape)\n", "unitvec = vec / np.sqrt(np.vdot(vec, vec))\n", "W_grad_numerical = (loss(W + eps / 2. 
* unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps\n", "print('W_dirderiv_numerical', W_grad_numerical)\n", "print('W_dirderiv_autodiff', np.vdot(grad(loss)(W, b), unitvec))" ], "execution_count": 12, "outputs": [ { "output_type": "stream", "text": [ "b_grad_numerical -0.29325485\n", "b_grad_autodiff -0.2922725\n", "W_dirderiv_numerical -0.19550323\n", "W_dirderiv_autodiff -0.19909078\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "InzB-iiJpVcx" }, "cell_type": "markdown", "source": [ "JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:" ] }, { "metadata": { "colab_type": "code", "id": "6Ok2LEfQmOuy", "colab": {} }, "cell_type": "code", "source": [ "from jax.test_util import check_grads\n", "check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "id0DXxwt3VJi" }, "cell_type": "markdown", "source": [ "### Hessian-vector products with `grad`-of-`grad`\n", "\n", "One thing we can do with higher-order `grad` is build a Hessian-vector product function. (Later on we'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)\n", "\n", "A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)).\n", "\n", "For a scalar-valued function $f : \\mathbb{R}^n \\to \\mathbb{R}$, the Hessian at a point $x \\in \\mathbb{R}^n$ is written as $\\partial^2 f(x)$. 
A Hessian-vector product function is then able to evaluate\n", "\n", "$\\qquad v \\mapsto \\partial^2 f(x) \\cdot v$\n", "\n", "for any $v \\in \\mathbb{R}^n$.\n", "\n", "The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.\n", "\n", "Luckily, `grad` already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity\n", "\n", "$\\qquad \\partial^2 f (x) v = \\partial [x \\mapsto \\partial f(x) \\cdot v] = \\partial g(x)$,\n", "\n", "where $g(x) = \\partial f(x) \\cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that we're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know `grad` is efficient.\n", "\n", "In JAX code, we can just write this:" ] }, { "metadata": { "colab_type": "code", "id": "Ou5OU-gU9epm", "colab": {} }, "cell_type": "code", "source": [ "def hvp(f, x, v):\n", " # Differentiate the scalar function x -> vdot(grad(f)(x), v), then evaluate it at x.\n", " return grad(lambda x: np.vdot(grad(f)(x), v))(x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "Rb1-5Hpv-ZV0" }, "cell_type": "markdown", "source": [ "This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.\n", "\n", "We'll check this implementation a few cells down, once we see how to compute dense Hessian matrices." 
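, "\n", "\n", "To see the identity at work before we have dense Hessians, here is a self-contained sketch on a hypothetical test function $f(x) = \\sum_i x_i^3$. Its Hessian is the diagonal matrix $\\mathrm{diag}(6x)$, so the Hessian-vector product should be $6 x \\odot v$ elementwise. The helper is named `hvp_sketch` so it does not shadow `hvp`:" ] },
{ "metadata": { "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import jax.numpy as np\n", "from jax import grad\n", "\n", "def hvp_sketch(f, x, v):\n", "  # Differentiate the scalar function x -> <grad f(x), v>, then evaluate it at x.\n", "  return grad(lambda x: np.vdot(grad(f)(x), v))(x)\n", "\n", "# Hypothetical test function: f(x) = sum_i x_i**3, with Hessian diag(6 x).\n", "cube_sum = lambda x: np.sum(x ** 3)\n", "x_test = np.array([1.0, 2.0, 3.0])\n", "v_test = np.array([1.0, 0.5, -1.0])\n", "\n", "# Expected result: 6 * x_test * v_test = [6., 6., -18.]\n", "print(hvp_sketch(cube_sum, x_test, v_test))" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "colab_type": "text" }, "cell_type": "markdown", "source": [ "Note the trailing `(x)`: `grad(...)` on its own returns a function, so it must be applied at the primal point to produce the Hessian-vector product."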
] }, { "metadata": { "colab_type": "text", "id": "5A_akvtp8UTu" }, "cell_type": "markdown", "source": [ "## Jacobians and Hessians using `jacfwd` and `jacrev`" ] }, { "metadata": { "colab_type": "text", "id": "UP5BbmSm8ZwK" }, "cell_type": "markdown", "source": [ "You can compute full Jacobian matrices using the `jacfwd` and `jacrev` functions:" ] }, { "metadata": { "colab_type": "code", "id": "cbETzAvKvf5I", "outputId": "49be4d16-789b-4cc0-92ad-4da3ae9152a2", "colab": { "base_uri": "https://localhost:8080/", "height": 190 } }, "cell_type": "code", "source": [ "from jax import jacfwd, jacrev\n", "\n", "# Isolate the function from the weight matrix to the predictions\n", "f = lambda W: predict(W, b, inputs)\n", "\n", "J = jacfwd(f)(W)\n", "print(\"jacfwd result, with shape\", J.shape)\n", "print(J)\n", "\n", "J = jacrev(f)(W)\n", "print(\"jacrev result, with shape\", J.shape)\n", "print(J)" ], "execution_count": 15, "outputs": [ { "output_type": "stream", "text": [ "jacfwd result, with shape (4, 3)\n", "[[ 0.05981753 0.12883775 0.08857596]\n", " [ 0.04015912 -0.0492862 0.0068453 ]\n", " [ 0.1218829 0.01406341 -0.30470726]\n", " [ 0.00140427 -0.00472519 0.00263776]]\n", "jacrev result, with shape (4, 3)\n", "[[ 0.05981753 0.12883775 0.08857595]\n", " [ 0.04015912 -0.0492862 0.00684531]\n", " [ 0.1218829 0.01406341 -0.30470726]\n", " [ 0.00140427 -0.00472519 0.00263776]]\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "iZDL-n_AvgBt" }, "cell_type": "markdown", "source": [ "These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for \"tall\" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for \"wide\" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`." 
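, "\n", "\n", "As a rough illustration, the sketch below builds two hypothetical functions. `tall` maps $\\mathbb{R}^2 \\to \\mathbb{R}^{100}$, so its Jacobian is tall and `jacfwd` needs work proportional to only 2 JVPs; `wide` maps $\\mathbb{R}^{100} \\to \\mathbb{R}^2$, so its Jacobian is wide and `jacrev` needs work proportional to only 2 VJPs. Either function works in either mode; only the cost differs:" ] },
{ "metadata": { "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import jax.numpy as np\n", "from jax import jacfwd, jacrev\n", "\n", "# Hypothetical example functions, chosen only for the shapes of their Jacobians.\n", "A = np.arange(200.0).reshape(100, 2) / 1000.0\n", "tall = lambda x: np.sin(np.dot(A, x))    # R^2 -> R^100, Jacobian shape (100, 2)\n", "wide = lambda x: np.sin(np.dot(A.T, x))  # R^100 -> R^2, Jacobian shape (2, 100)\n", "\n", "x_small, x_big = np.ones(2), np.ones(100)\n", "J_tall_fwd, J_tall_rev = jacfwd(tall)(x_small), jacrev(tall)(x_small)\n", "J_wide_fwd, J_wide_rev = jacfwd(wide)(x_big), jacrev(wide)(x_big)\n", "\n", "print(J_tall_fwd.shape, J_wide_rev.shape)\n", "print(np.allclose(J_tall_fwd, J_tall_rev), np.allclose(J_wide_fwd, J_wide_rev))" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "colab_type": "text" }, "cell_type": "markdown", "source": [ "At this toy size the timing difference is likely negligible; timing the calls with `%timeit` on much larger versions of `A` is one way to see the asymmetry."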
] }, { "metadata": { "id": "zeKlr7Xz8bfm", "colab_type": "text" }, "cell_type": "markdown", "source": [ "You can also use `jacfwd` and `jacrev` with container types:" ] }, { "metadata": { "id": "eH46Xnm88bfm", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 138 }, "outputId": "96212e27-23e6-4485-c75f-c96a8c6d13a0" }, "cell_type": "code", "source": [ "def predict_dict(params, inputs):\n", " return predict(params['W'], params['b'], inputs)\n", "\n", "J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)\n", "for k, v in J_dict.items():\n", " print(\"Jacobian from {} to logits is\".format(k))\n", " print(v)" ], "execution_count": 16, "outputs": [ { "output_type": "stream", "text": [ "Jacobian from b to logits is\n", "[0.11503371 0.04563536 0.2343902 0.00189767]\n", "Jacobian from W to logits is\n", "[[ 0.05981753 0.12883775 0.08857595]\n", " [ 0.04015912 -0.0492862 0.00684531]\n", " [ 0.1218829 0.01406341 -0.30470726]\n", " [ 0.00140427 -0.00472519 0.00263776]]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "yH34zjV88bfp", "colab_type": "text" }, "cell_type": "markdown", "source": [ "For more details on forward- and reverse-mode, as well as how to implement `jacfwd` and `jacrev` as efficiently as possible, read on!" 
] }, { "metadata": { "id": "K6Mpw_7K8bfp", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Using a composition of two of these functions gives us a way to compute dense Hessian matrices:" ] }, { "metadata": { "colab_type": "code", "id": "n155ypD9rfIZ", "outputId": "1f9665bf-9459-4ee5-a902-125dcd14b225", "colab": { "base_uri": "https://localhost:8080/", "height": 294 } }, "cell_type": "code", "source": [ "def hessian(f):\n", " return jacfwd(jacrev(f))\n", "\n", "H = hessian(f)(W)\n", "print(\"hessian, with shape\", H.shape)\n", "print(H)" ], "execution_count": 17, "outputs": [ { "output_type": "stream", "text": [ "hessian, with shape (4, 3, 3)\n", "[[[ 0.02285464 0.04922539 0.03384245]\n", " [ 0.04922538 0.10602392 0.07289143]\n", " [ 0.03384245 0.07289144 0.05011286]]\n", "\n", " [[-0.03195212 0.03921397 -0.00544638]\n", " [ 0.03921397 -0.04812624 0.0066842 ]\n", " [-0.00544638 0.0066842 -0.00092836]]\n", "\n", " [[-0.01583708 -0.00182736 0.03959271]\n", " [-0.00182736 -0.00021085 0.00456839]\n", " [ 0.03959271 0.00456839 -0.09898178]]\n", "\n", " [[-0.00103521 0.00348334 -0.00194452]\n", " [ 0.00348334 -0.01172098 0.00654304]\n", " [-0.00194452 0.00654304 -0.00365254]]]\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "wvkk82R6uRoM" }, "cell_type": "markdown", "source": [ "This shape makes sense: if we start with a function $f : \\mathbb{R}^n \\to \\mathbb{R}^m$, then at a point $x \\in \\mathbb{R}^n$ we expect to get the shapes\n", "* $f(x) \\in \\mathbb{R}^m$, the value of $f$ at $x$,\n", "* $\\partial f(x) \\in \\mathbb{R}^{m \\times n}$, the Jacobian matrix at $x$,\n", "* $\\partial^2 f(x) \\in \\mathbb{R}^{m \\times n \\times n}$, the Hessian at $x$,\n", "\n", "and so on.\n", "\n", "To implement `hessian`, we could have used `jacrev(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of the two. But forward-over-reverse is typically the most efficient. 
That's because in the inner Jacobian computation we're often differentiating a function with a wide Jacobian (such as a loss function $f : \\mathbb{R}^n \\to \\mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\\nabla f : \\mathbb{R}^n \\to \\mathbb{R}^n$), which is where forward-mode wins out." ] }, { "metadata": { "colab_type": "text", "id": "OMmi9cyhs1bj" }, "cell_type": "markdown", "source": [ "## How it's made: two foundational autodiff functions" ] }, { "metadata": { "colab_type": "text", "id": "mtSRvouV6vvG" }, "cell_type": "markdown", "source": [ "### Jacobian-Vector products (JVPs, aka forward-mode autodiff)\n", "\n", "JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar `grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, we need a bit of math background.\n", "\n", "#### JVPs in math\n", "\n", "Mathematically, given a function $f : \\mathbb{R}^n \\to \\mathbb{R}^m$, the Jacobian matrix of $f$ evaluated at an input point $x \\in \\mathbb{R}^n$, denoted $\\partial f(x)$, is often thought of as an $m \\times n$ matrix:\n", "\n", "$\\qquad \\partial f(x) \\in \\mathbb{R}^{m \\times n}$.\n", "\n", "But we can also think of $\\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\\mathbb{R}^m$):\n", "\n", "$\\qquad \\partial f(x) : \\mathbb{R}^n \\to \\mathbb{R}^m$.\n", "\n", "This map is called the [pushforward map](https://en.wikipedia.org/wiki/Pushforward_(differential)) of $f$ at $x$. 
The Jacobian matrix is just the matrix for this linear map in a standard basis.\n", "\n", "If we don't commit to one specific input point $x$, then we can think of the function $\\partial f$ as first taking an input point and returning the Jacobian linear map at that input point:\n", "\n", "$\\qquad \\partial f : \\mathbb{R}^n \\to \\mathbb{R}^n \\to \\mathbb{R}^m$.\n", "\n", "In particular, we can uncurry things so that given input point $x \\in \\mathbb{R}^n$ and a tangent vector $v \\in \\mathbb{R}^n$, we get back an output tangent vector in $\\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*, and write it as\n", "\n", "$\\qquad (x, v) \\mapsto \\partial f(x) v$\n", "\n", "#### JVPs in JAX code\n", "\n", "Back in Python code, JAX's `jvp` function models this transformation. Given a Python function that evaluates $f$, JAX's `jvp` is a way to get a Python function for evaluating $(x, v) \\mapsto (f(x), \\partial f(x) v)$." ] }, { "metadata": { "colab_type": "code", "id": "pTncYR6F6vvG", "colab": {} }, "cell_type": "code", "source": [ "from jax import jvp\n", "\n", "# Isolate the function from the weight matrix to the predictions\n", "f = lambda W: predict(W, b, inputs)\n", "\n", "key, subkey = random.split(key)\n", "v = random.normal(subkey, W.shape)\n", "\n", "# Push forward the vector `v` along `f` evaluated at `W`\n", "y, u = jvp(f, (W,), (v,))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "m1VJgJYQGfCK" }, "cell_type": "markdown", "source": [ "In terms of Haskell-like type signatures, we could write\n", "\n", "```haskell\n", "jvp :: (a -> b) -> a -> T a -> (b, T b)\n", "```\n", "\n", "where we use `T a` to denote the type of the tangent space for `a`. In words, `jvp` takes as arguments a function of type `a -> b`, a value of type `a`, and a tangent vector value of type `T a`. 
It gives back a pair consisting of a value of type `b` and an output tangent vector of type `T b`." ] }, { "metadata": { "colab_type": "text", "id": "3RpbiasHGD3X" }, "cell_type": "markdown", "source": [ "The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a \"JVP rule\" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values.\n", "\n", "That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 2x the cost of just evaluating the function. Put another way, for a fixed primal point $x$, we can evaluate $v \\mapsto \\partial f(x) \\cdot v$ for about the same cost as evaluating $f$.\n", "\n", "That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning?\n", "\n", "To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with \"tall\" Jacobians, but inefficient for \"wide\" Jacobians.\n", "\n", "If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\\mathbb{R}^n$ to a scalar loss value in $\\mathbb{R}$. 
That means the Jacobian of this function is a very wide matrix: $\\partial f(x) \\in \\mathbb{R}^{1 \\times n}$, which we often identify with the gradient vector $\\nabla f(x) \\in \\mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.\n", "\n", "To do better for functions like this, we just need to use reverse-mode." ] }, { "metadata": { "colab_type": "text", "id": "PhkvkZazdXu1" }, "cell_type": "markdown", "source": [ "### Vector-Jacobian products (VJPs, aka reverse-mode autodiff)\n", "\n", "Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.\n", "\n", "#### VJPs in math\n", "\n", "Let's again consider a function $f : \\mathbb{R}^n \\to \\mathbb{R}^m$.\n", "Starting from our notation for JVPs, the notation for VJPs is pretty simple:\n", "\n", "$\\qquad (x, v) \\mapsto v \\partial f(x)$,\n", "\n", "where $v$ is an element of the cotangent space of the codomain of $f$ at $f(x)$ (isomorphic to another copy of $\\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \\mathbb{R}^m \\to \\mathbb{R}$, and when we write $v \\partial f(x)$ we mean function composition $v \\circ \\partial f(x)$, where the types work out because $\\partial f(x) : \\mathbb{R}^n \\to \\mathbb{R}^m$. 
But in the common case we can identify it with a vector in $\\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between \"column vectors\" and \"row vectors\" without much comment.\n", "\n", "With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:\n", "\n", "$\\qquad (x, v) \\mapsto \\partial f(x)^\\mathsf{T} v$.\n", "\n", "For a given point $x$, we can write the signature as\n", "\n", "$\\qquad \\partial f(x)^\\mathsf{T} : \\mathbb{R}^m \\to \\mathbb{R}^n$.\n", "\n", "The corresponding map on cotangent spaces is often called the [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry))\n", "of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function.\n", "\n", "#### VJPs in JAX code\n", "\n", "Switching from math back to Python, the JAX function `vjp` can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \\mapsto (f(x), v^\\mathsf{T} \\partial f(x))$." 
] }, { "metadata": { "colab_type": "code", "id": "1tFcRuEzkGRR", "colab": {} }, "cell_type": "code", "source": [ "from jax import vjp\n", "\n", "# Isolate the function from the weight matrix to the predictions\n", "f = lambda W: predict(W, b, inputs)\n", "\n", "y, vjp_fun = vjp(f, W)\n", "\n", "key, subkey = random.split(key)\n", "u = random.normal(subkey, y.shape)\n", "\n", "# Pull back the covector `u` along `f` evaluated at `W`.\n", "# `vjp_fun` returns a tuple with one cotangent per positional argument of `f`.\n", "v = vjp_fun(u)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "oVOZexCEkvv3" }, "cell_type": "markdown", "source": [ "In terms of Haskell-like type signatures, we could write\n", "\n", "```haskell\n", "vjp :: (a -> b) -> a -> (b, CT b -> CT a)\n", "```\n", "\n", "where we use `CT a` to denote the type of the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.\n", "\n", "This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \\mapsto (f(x), v^\\mathsf{T} \\partial f(x))$ is only about twice the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \\mathbb{R}^n \\to \\mathbb{R}$, we can do it in just one call. That's how `grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.\n", "\n", "There's a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).\n", "\n", "For more on how reverse-mode works, see [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/)." 
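, "\n", "\n", "To connect this back to `grad`: for a scalar-valued function, a single VJP seeded with the cotangent $1$ already yields the whole gradient. A minimal sketch with a hypothetical function $h(x) = \\sum_i \\sin(x_i)$, whose gradient is $\\cos(x)$:" ] },
{ "metadata": { "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import jax.numpy as np\n", "from jax import grad, vjp\n", "\n", "# Hypothetical scalar-valued function h(x) = sum_i sin(x_i), whose gradient is cos(x).\n", "h = lambda x: np.sum(np.sin(x))\n", "x_h = np.array([0.0, 1.0, 2.0])\n", "\n", "h_value, h_vjp = vjp(h, x_h)\n", "# Pull back a single cotangent equal to 1 (same shape and dtype as the output).\n", "# The pullback returns a tuple with one entry per positional argument of h.\n", "grad_via_vjp, = h_vjp(np.ones_like(h_value))\n", "\n", "print(grad_via_vjp)\n", "print(grad(h)(x_h))" ], "execution_count": 0, "outputs": [] },
{ "metadata": { "colab_type": "text" }, "cell_type": "markdown", "source": [ "Both lines should print the same values: one reverse pass, seeded with 1, recovers the full gradient of a scalar function."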
] }, { "metadata": { "colab_type": "text", "id": "xtqSUJgzwQXO" }, "cell_type": "markdown", "source": [ "## Composing VJPs, JVPs, and `vmap`" ] }, { "metadata": { "colab_type": "text", "id": "PSL1TciM6vvI" }, "cell_type": "markdown", "source": [ "### Jacobian-Matrix and Matrix-Jacobian products\n", "\n", "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's [`vmap` transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." ] }, { "metadata": { "colab_type": "code", "id": "asAWvxVaCmsx", "outputId": "a6348d8f-53c6-475e-a97a-46d052208cf1", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "cell_type": "code", "source": [ "# Isolate the function from the weight matrix to the predictions\n", "f = lambda W: predict(W, b, inputs)\n", "\n", "# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.\n", "# First, use a list comprehension to loop over rows in the matrix M.\n", "def loop_mjp(f, x, M):\n", " y, vjp_fun = vjp(f, x)\n", " return np.vstack([vjp_fun(mi) for mi in M])\n", "\n", "# Now, use vmap to build a computation that does a single fast matrix-matrix\n", "# multiply, rather than an outer loop over vector-matrix multiplies.\n", "def vmap_mjp(f, x, M):\n", " y, vjp_fun = vjp(f, x)\n", " return vmap(vjp_fun)(M)\n", "\n", "key = random.PRNGKey(0)\n", "num_covecs = 128\n", "U = random.normal(key, (num_covecs,) + y.shape)\n", "\n", "loop_vs = loop_mjp(f, W, M=U)\n", "print('Non-vmapped Matrix-Jacobian product')\n", "%timeit -n10 -r3 loop_mjp(f, W, M=U)\n", "\n", "print('\\nVmapped Matrix-Jacobian product')\n", "vmap_vs = vmap_mjp(f, W, M=U)\n", "%timeit -n10 -r3 vmap_mjp(f, W, M=U)\n", "\n", "assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'" 
], "execution_count": 20, "outputs": [ { "output_type": "stream", "text": [ "Non-vmapped Matrix-Jacobian product\n", "10 loops, best of 3: 146 ms per loop\n", "\n", "Vmapped Matrix-Jacobian product\n", "10 loops, best of 3: 6.85 ms per loop\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "code", "id": "TDaxsJrlDraK", "outputId": "f38ba8d3-f81e-4187-eeda-65257056bfeb", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "cell_type": "code", "source": [ "def loop_jmp(f, x, M):\n", " # jvp immediately returns the primal and tangent values as a tuple,\n", " # so we'll compute and select the tangents in a list comprehension\n", " return np.vstack([jvp(f, (x,), (si,))[1] for si in M])\n", "\n", "def vmap_jmp(f, x, M):\n", " _jvp = lambda s: jvp(f, (x,), (s,))[1]\n", " return vmap(_jvp)(M)\n", "\n", "num_vecs = 128\n", "S = random.normal(key, (num_vecs,) + W.shape)\n", "\n", "loop_vs = loop_jmp(f, W, M=S)\n", "print('Non-vmapped Jacobian-Matrix product')\n", "%timeit -n10 -r3 loop_jmp(f, W, M=S)\n", "vmap_vs = vmap_jmp(f, W, M=S)\n", "print('\\nVmapped Jacobian-Matrix product')\n", "%timeit -n10 -r3 vmap_jmp(f, W, M=S)\n", "\n", "assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'" ], "execution_count": 21, "outputs": [ { "output_type": "stream", "text": [ "Non-vmapped Jacobian-Matrix product\n", "10 loops, best of 3: 525 ms per loop\n", "\n", "Vmapped Jacobian-Matrix product\n", "10 loops, best of 3: 5.6 ms per loop\n" ], "name": "stdout" } ] }, { "metadata": { "colab_type": "text", "id": "MXFEFBDz6vvL" }, "cell_type": "markdown", "source": [ "### The implementation of `jacfwd` and `jacrev`\n", "\n" ] }, { "metadata": { "id": "ZAgUb6sp8bf7", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write `jacfwd` and `jacrev`. 
We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once." ] }, { "metadata": { "colab_type": "code", "id": "HBEzsDH1U5_4", "colab": {} }, "cell_type": "code", "source": [ "from jax import jacrev as builtin_jacrev\n", "\n", "def our_jacrev(f):\n", " def jacfun(x):\n", " y, vjp_fun = vjp(f, x)\n", " # Use vmap to do a matrix-Jacobian product.\n", " # Here, the matrix is the Euclidean basis, so we get all\n", " # entries in the Jacobian at once.\n", " J, = vmap(vjp_fun, in_axes=0)(np.eye(len(y)))\n", " return J\n", " return jacfun\n", "\n", "assert np.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "code", "id": "Qd9gVZ5t6vvP", "colab": {} }, "cell_type": "code", "source": [ "from jax import jacfwd as builtin_jacfwd\n", "\n", "def our_jacfwd(f):\n", " def jacfun(x):\n", " _jvp = lambda s: jvp(f, (x,), (s,))[1]\n", " Jt = vmap(_jvp, in_axes=1)(np.eye(len(x)))\n", " return np.transpose(Jt)\n", " return jacfun\n", "\n", "assert np.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "7r5_m9Y68bf_", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Interestingly, [Autograd](https://github.com/hips/autograd) couldn't do this. Our [implementation of reverse-mode `jacobian` in Autograd](https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) had to pull back one vector at a time with an outer-loop `map`. Pushing one vector at a time through the computation is much less efficient than batching it all together with `vmap`." ] }, { "metadata": { "id": "9maev0Nd8bf_", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Another thing that Autograd couldn't do is `jit`. 
Notably, no matter how much Python dynamism you use in the function being differentiated, we can always use `jit` on the linear part of the computation (for instance, the function returned by `vjp`). For example:" ] }, { "metadata": { "id": "_5jDflC08bgB", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "outputId": "cb727406-e9bc-44b8-eb70-31636abef29c" }, "cell_type": "code", "source": [ "def f(x):\n", " try:\n", " if x < 3:\n", " return 2 * x ** 3\n", " else:\n", " raise ValueError\n", " except ValueError:\n", " return np.pi * x\n", "\n", "y, f_vjp = vjp(f, 4.)\n", "print(jit(f_vjp)(1.))" ], "execution_count": 24, "outputs": [ { "output_type": "stream", "text": [ "(array(3.1415927, dtype=float32),)\n" ], "name": "stdout" } ] }, { "metadata": { "id": "3fPWLrxK8bgD", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Complex numbers and differentiation" ] }, { "metadata": { "id": "2pZOHvrm8bgE", "colab_type": "text" }, "cell_type": "markdown", "source": [ "JAX supports differentiating functions of complex numbers. To support both [holomorphic and non-holomorphic differentiation](https://en.wikipedia.org/wiki/Holomorphic_function), JAX follows [Autograd's convention](https://github.com/HIPS/autograd/blob/master/docs/tutorial.md#complex-numbers) for encoding complex derivatives.\n", "\n", "Consider a complex-to-complex function $f: \\mathbb{C} \\to \\mathbb{C}$ that we break down into its component real-to-real functions:" ] }, { "metadata": { "id": "OaqZ2MuP8bgF", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def f(z):\n", " # u and v stand for the real and imaginary component functions of f\n", " x, y = np.real(z), np.imag(z)\n", " return u(x, y) + v(x, y) * 1j" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "3XB5oGxl8bgH", "colab_type": "text" }, "cell_type": "markdown", "source": [ "That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$. 
We define `grad(f)` to correspond to" ] }, { "metadata": { "id": "hINSv9TS8bgH", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def grad_f(z):\n", " x, y = np.real(z), np.imag(z)\n", " return grad(u, 0)(x, y) - grad(u, 1)(x, y) * 1j" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "4j0F28bB8bgK", "colab_type": "text" }, "cell_type": "markdown", "source": [ "In math symbols, that means we define $\\partial f(z) \\triangleq \\partial_0 u(x, y) - \\partial_1 u(x, y) i$. So we throw out $v$, ignoring the imaginary component function of $f$ entirely!" ] }, { "metadata": { "id": "wLxn8qfC8bgL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "This convention covers three important cases:\n", "1. If `f` evaluates a holomorphic function, then we get the usual complex derivative, since the Cauchy-Riemann equations $\\partial_0 u = \\partial_1 v$ and $\\partial_1 u = - \\partial_0 v$ give $\\partial_0 u - \\partial_1 u i = \\partial_0 u + \\partial_0 v i = f'(z)$.\n", "2. If `f` evaluates a real-valued loss function of a complex parameter `x`, then we get a result that we can use in gradient-based optimization by taking steps in the direction of the conjugate of `grad(f)(x)`.\n", "3. If `f` evaluates a real-to-real function, but its implementation uses complex primitives internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions), then we get the same result that an implementation that only used real primitives would have given.\n", "\n", "By throwing away `v` entirely, this convention does not handle the case where `f` evaluates a non-holomorphic function and you want to evaluate all of $\\partial_0 u$, $\\partial_1 u$, $\\partial_0 v$, and $\\partial_1 v$ at once. But in that case the answer would have to contain four real values, and so there's no way to express it as a single complex number." ] }, { "metadata": { "id": "qmXkI37T8bgL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "You should expect complex numbers to work everywhere in JAX. 
Here's an example of differentiating through a Cholesky decomposition of a complex Hermitian matrix:" ] }, { "metadata": { "id": "WrDHHfKI8bgM", "colab_type": "code", "colab": {}, "outputId": "92338540-e4e8-45f2-a614-c2d2dfc8d5ba" }, "cell_type": "code", "source": [ "A = np.array([[5., 2.+3j, 5j],\n", " [2.-3j, 7., 1.+7j],\n", " [-5j, 1.-7j, 12.]])\n", "\n", "def f(X):\n", " L = np.linalg.cholesky(X)\n", " return np.sum((L - np.sin(L))**2)\n", "\n", "grad(f)(A)" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[ 1.6623291e+01+0.j , -1.3631370e+00-5.6038527j,\n", " -1.8995690e+00+9.700885j ],\n", " [-1.3631370e+00+5.6038527j, -8.9385948e+00+0.j ,\n", " -5.1351528e+00-6.5743794j],\n", " [-1.8995690e+00-9.700885j , -5.1351528e+00+6.5743794j,\n", " 1.3204219e-02+0.j ]], dtype=complex64)" ] }, "metadata": { "tags": [] }, "execution_count": 25 } ] }, { "metadata": { "id": "X_3a2b6M8bgP", "colab_type": "text" }, "cell_type": "markdown", "source": [ "For primitives' JVP rules, writing the primals as $z = a + bi$ and the tangents as $t = c + di$, we define the Jacobian-vector product $t \\mapsto \\partial f(z) \\cdot t$ as\n", "\n", "$t \\mapsto\n", "\\begin{matrix} \\begin{bmatrix} 1 & 1 \\end{bmatrix} \\\\ ~ \\end{matrix}\n", "\\begin{bmatrix} \\partial_0 u(a, b) & -\\partial_0 v(a, b) \\\\ - \\partial_1 u(a, b) i & \\partial_1 v(a, b) i \\end{bmatrix}\n", "\\begin{bmatrix} c \\\\ d \\end{bmatrix}$." ] }, { "metadata": { "id": "dWwnI2Iz8bgP", "colab_type": "text" }, "cell_type": "markdown", "source": [ "See Chapter 4 of [Dougal's PhD thesis](https://dougalmaclaurin.com/phd-thesis.pdf) for more details." ] }, { "metadata": { "id": "Pgr2A60q9gl1", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# More advanced autodiff\n", "\n", "In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. 
We hope you now feel that taking derivatives in JAX is easy and powerful.\n", "\n", "There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an \"Advanced Autodiff Cookbook\", include:\n", "\n", " - Gauss-Newton vector products, linearizing once\n", " - Custom VJPs and JVPs\n", " - Efficient derivatives at fixed points\n", " - Estimating the trace of a Hessian using random Hessian-vector products\n", " - Forward-mode autodiff using only reverse-mode autodiff\n", " - Taking derivatives with respect to custom data types\n", " - Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting)\n", " - Optimizing VJPs with Jacobian pre-accumulation" ] } ] }