{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "JAX Quickstart.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "accelerator": "GPU" }, "cells": [ { "metadata": { "id": "logZcM_HEnve", "colab_type": "text" }, "cell_type": "markdown", "source": [ "##### Copyright 2018 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "metadata": { "id": "QwN47xiBEsKz", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", "https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "metadata": { "colab_type": "text", "id": "xtWX4x9DCF5_" }, "cell_type": "markdown", "source": [ "# JAX Quickstart\n", "Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary\n", "\n", "![](https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png)\n", "\n", "#### [JAX](https://github.com/google/jax) is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.\n", "\n", "With its updated version of [Autograd](https://github.com/hips/autograd), JAX\n", "can automatically differentiate native Python and NumPy code. 
It can\n", "differentiate through a large subset of Python’s features, including loops, ifs,\n", "recursion, and closures, and it can even take derivatives of derivatives of\n", "derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily\n", "to any order.\n", "\n", "What’s new is that JAX uses\n", "[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md)\n", "to compile and run your NumPy code on accelerators, like GPUs and TPUs.\n", "Compilation happens under the hood by default, with library calls getting\n", "just-in-time compiled and executed. But JAX even lets you just-in-time compile\n", "your own Python functions into XLA-optimized kernels using a one-function API.\n", "Compilation and automatic differentiation can be composed arbitrarily, so you\n", "can express sophisticated algorithms and get maximal performance without having\n", "to leave Python.\n" ] }, { "metadata": { "colab_type": "code", "id": "PaW85yP_BrCF", "colab": {} }, "cell_type": "code", "source": [ "!pip install --upgrade https://storage.googleapis.com/jax-wheels/cuda92/jaxlib-0.1.6-cp36-none-linux_x86_64.whl\n", "!pip install --upgrade jax" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "code", "id": "SY8mDvEvCGqk", "colab": {} }, "cell_type": "code", "source": [ "from __future__ import print_function, division\n", "import jax.numpy as np\n", "from jax import grad, jit, vmap\n", "from jax import random" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "FQ89jHCYfhpg" }, "cell_type": "markdown", "source": [ "### Multiplying Matrices" ] }, { "metadata": { "colab_type": "text", "id": "Xpy1dSgNqCP4" }, "cell_type": "markdown", "source": [ "We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers: JAX has no hidden global random state, so every random function takes an explicit PRNG key. For more details, see the readme."
] }, { "metadata": { "colab_type": "code", "id": "u0nseKZNqOoH", "colab": {} }, "cell_type": "code", "source": [ "key = random.PRNGKey(0)\n", "x = random.normal(key, (10,))\n", "print(x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "hDJF0UPKnuqB" }, "cell_type": "markdown", "source": [ "Let's dive right in and multiply two big matrices." ] }, { "metadata": { "colab_type": "code", "id": "eXn8GUl6CG5N", "colab": {} }, "cell_type": "code", "source": [ "size = 3000\n", "x = random.normal(key, (size, size), dtype=np.float32)\n", "%timeit np.dot(x, x.T) # runs on the GPU" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "0AlN7EbonyaR" }, "cell_type": "markdown", "source": [ "JAX NumPy functions work on regular NumPy arrays." ] }, { "metadata": { "colab_type": "code", "id": "ZPl0MuwYrM7t", "colab": {} }, "cell_type": "code", "source": [ "import numpy as onp # original CPU-backed NumPy\n", "x = onp.random.normal(size=(size, size)).astype(onp.float32)\n", "%timeit np.dot(x, x.T)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "_SrcB2IurUuE" }, "cell_type": "markdown", "source": [ "That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using `device_put`." ] }, { "metadata": { "colab_type": "code", "id": "Jj7M7zyRskF0", "colab": {} }, "cell_type": "code", "source": [ "from jax import device_put\n", "\n", "x = onp.random.normal(size=(size, size)).astype(onp.float32)\n", "x = device_put(x)\n", "%timeit np.dot(x, x.T)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "clO9djnen8qi" }, "cell_type": "markdown", "source": [ "The output of `device_put` still acts like an NDArray. 
By the way, the [implementation](https://github.com/google/jax/blob/c8f93511ecb977d02fa5b0eee17075706fd6fd93/jax/api.py#L162) of `device_put` is just `device_put = jit(lambda x: x)`." ] }, { "metadata": { "colab_type": "text", "id": "ghkfKNQttDpg" }, "cell_type": "markdown", "source": [ "If you have a GPU (or TPU!), these calls run on the accelerator and have the potential to be much faster than on the CPU. For comparison, here is the same product computed by plain NumPy on the CPU:" ] }, { "metadata": { "colab_type": "code", "id": "RzXK8GnIs7VV", "colab": {} }, "cell_type": "code", "source": [ "x = onp.random.normal(size=(size, size)).astype(onp.float32)\n", "%timeit onp.dot(x, x.T)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "iOzp0P_GoJhb" }, "cell_type": "markdown", "source": [ "JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:\n", "\n", " - `jit`, for speeding up your code\n", " - `grad`, for taking derivatives\n", " - `vmap`, for automatic vectorization or batching\n", "\n", "Let's go over these one by one. We'll also end up composing them in interesting ways." ] }, { "metadata": { "colab_type": "text", "id": "bTTrTbWvgLUK" }, "cell_type": "markdown", "source": [ "### Using `jit` to speed up functions" ] }, { "metadata": { "colab_type": "text", "id": "YrqE32mvE3b7" }, "cell_type": "markdown", "source": [ "JAX runs transparently on the GPU (or CPU, if you don't have one, and TPU coming soon!). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using [XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md). Let's try that."
] }, { "metadata": { "colab_type": "code", "id": "qLGdCtFKFLOR", "colab": {} }, "cell_type": "code", "source": [ "def selu(x, alpha=1.67, lmbda=1.05):\n", " return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)\n", "\n", "x = random.normal(key, (1000000,))\n", "%timeit selu(x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "a_V8SruVHrD_" }, "cell_type": "markdown", "source": [ "We can speed it up with `jit`, which compiles `selu` the first time it is called and caches the compiled version for later calls." ] }, { "metadata": { "colab_type": "code", "id": "fh4w_3NpFYTp", "colab": {} }, "cell_type": "code", "source": [ "selu_jit = jit(selu)\n", "%timeit selu_jit(x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "HxpBc4WmfsEU" }, "cell_type": "markdown", "source": [ "### Taking derivatives with `grad`\n", "\n", "In addition to evaluating numerical functions, we also want to transform them. One transformation is [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation). In JAX, just like in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the `grad` function." ] }, { "metadata": { "colab_type": "code", "id": "IMAgNJaMJwPD", "colab": {} }, "cell_type": "code", "source": [ "def sum_logistic(x):\n", " return np.sum(1.0 / (1.0 + np.exp(-x)))\n", "\n", "x_small = np.arange(3.)\n", "derivative_fn = grad(sum_logistic)\n", "print(derivative_fn(x_small))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "PtNs881Ohioc" }, "cell_type": "markdown", "source": [ "Let's verify with finite differences that our result is correct."
] }, { "metadata": { "colab_type": "code", "id": "JXI7_OZuKZVO", "colab": {} }, "cell_type": "code", "source": [ "def first_finite_differences(f, x):\n", " eps = 1e-3\n", " return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n", " for v in onp.eye(len(x))])\n", "\n", "\n", "print(first_finite_differences(sum_logistic, x_small))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "Q2CUZjOWNZ-3" }, "cell_type": "markdown", "source": [ "Taking derivatives is as easy as calling `grad`. `grad` and `jit` compose and can be mixed arbitrarily. The example above differentiated `sum_logistic` without jitting it. We can go further and interleave the two:" ] }, { "metadata": { "colab_type": "code", "id": "TO4g8ny-OEi4", "colab": {} }, "cell_type": "code", "source": [ "print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "yCJ5feKvhnBJ" }, "cell_type": "markdown", "source": [ "For more advanced autodiff, you can use `jax.vjp` for reverse-mode vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them, through the `jacfwd` and `jacrev` wrappers built on top of them, to make a function that efficiently computes full Hessian matrices:" ] }, { "metadata": { "colab_type": "code", "id": "Z-JxbiNyhxEW", "colab": {} }, "cell_type": "code", "source": [ "from jax import jacfwd, jacrev\n", "def hessian(fun):\n", " return jit(jacfwd(jacrev(fun)))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "TI4nPsGafxbL" }, "cell_type": "markdown", "source": [ "### Auto-vectorization with `vmap`" ] }, { "metadata": { "colab_type": "text", "id": "PcxkONy5aius" }, "cell_type": "markdown", "source": [ "JAX has one more transformation in its API that you might find useful: `vmap`, the vectorizing map. 
It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with `jit`, it can be just as fast as adding the batch dimensions by hand." ] }, { "metadata": { "colab_type": "text", "id": "TPiX4y-bWLFS" }, "cell_type": "markdown", "source": [ "We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using `vmap`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions." ] }, { "metadata": { "colab_type": "code", "id": "8w0Gpsn8WYYj", "colab": {} }, "cell_type": "code", "source": [ "mat = random.normal(key, (150, 100))\n", "batched_x = random.normal(key, (10, 100))\n", "\n", "def apply_matrix(v):\n", " return np.dot(mat, v)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "0zWsc0RisQWx", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Given a function such as `apply_matrix`, we can loop over a batch dimension in Python, but usually the performance of doing so is poor." ] }, { "metadata": { "colab_type": "code", "id": "KWVc9BsZv0Ki", "colab": {} }, "cell_type": "code", "source": [ "def naively_batched_apply_matrix(v_batched):\n", " return np.stack([apply_matrix(v) for v in v_batched])\n", "\n", "print('Naively batched')\n", "%timeit naively_batched_apply_matrix(batched_x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "qHfKaLE9stbA", "colab_type": "text" }, "cell_type": "markdown", "source": [ "We know how to batch this operation manually. In this case, `np.dot` handles extra batch dimensions transparently." 
] }, { "metadata": { "colab_type": "code", "id": "ipei6l8nvrzH", "colab": {} }, "cell_type": "code", "source": [ "@jit\n", "def batched_apply_matrix(v_batched):\n", " return np.dot(v_batched, mat.T)\n", "\n", "print('Manually batched')\n", "%timeit batched_apply_matrix(batched_x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "1eF8Nhb-szAb", "colab_type": "text" }, "cell_type": "markdown", "source": [ "However, suppose we had a more complicated function without batching support. We can use `vmap` to add batching support automatically." ] }, { "metadata": { "colab_type": "code", "id": "67Oeknf5vuCl", "colab": {} }, "cell_type": "code", "source": [ "@jit\n", "def vmap_batched_apply_matrix(batched_x):\n", " return vmap(apply_matrix)(batched_x)\n", "\n", "print('Auto-vectorized with vmap')\n", "%timeit vmap_batched_apply_matrix(batched_x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "pYVl3Z2nbZhO" }, "cell_type": "markdown", "source": [ "Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other JAX transformation." ] }, { "metadata": { "id": "WwNnjaI4th_8", "colab_type": "text" }, "cell_type": "markdown", "source": [ "This is just a taste of what JAX can do. We're really excited to see what you do with it!" ] } ] }