{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "JAX generalized ufuncs.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "metadata": { "id": "435_M09vl3NA", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Extending JAX's vmap to work like NumPy's gufuncs\n", "\n", "by [Stephan Hoyer](https://github.com/shoyer)\n", "\n", "## What is a gufunc?\n", "\n", "[Generalized universal functions](https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html) (\"gufuncs\") are one of my favorite abstractions from NumPy. They generalize NumPy's [broadcasting rules](https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html) to handle non-scalar operations. When a gufunc is applied to arrays, there are:\n", "- \"core dimensions\" over which an operation is defined.\n", "- \"broadcast dimensions\" over which operations can be automatically vectorized.\n", "\n", "A string [signature](https://docs.scipy.org/doc/numpy-1.15.0/reference/c-api.generalized-ufuncs.html#details-of-signature) associated with each gufunc controls how this happens by indicating how core dimensions are mapped between inputs and outputs. The syntax is easiest to understand by looking at a few examples:\n", "\n", "- Addition: `(),()->()`\n", "- 1D inner product: `(i),(i)->()`\n", "- 1D sum: `(i)->()`\n", "- Matrix multiplication: `(m,n),(n,k)->(m,k)`\n", "\n", "## Why write gufuncs?\n", "\n", "From a user perspective, gufuncs are nice because they're guaranteed to vectorize in a consistent and general fashion. 
For example, by default gufuncs use the last dimensions of arrays as core dimensions, but you can control that explicitly with the `axis` or `axes` arguments.\n", "\n", "From a developer perspective, gufuncs are nice because they simplify your work: you only need to think about the core logic of your function, not how it handles arbitrary dimensional input. You can just write that down in a simple, declarative way.\n", "\n", "## JAX makes it easy to write high-level performant code\n", "\n", "Unfortunately, writing NumPy gufuncs today is somewhat non-trivial. Your options today are:\n", "\n", "1. Write the inner loops yourself in C.\n", "2. [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) creates something kind of like a gufunc, but it's painfully slow: the outer loop is performed in Python.\n", "3. [`numba.guvectorize`](https://numba.pydata.org/numba-doc/dev/user/vectorize.html) can work well, if you don't need further code transformations like automatic differentiation.\n", "\n", "JAX's `vmap` contains all the core functionality we need to write functions that work like gufuncs. 
JAX gufuncs play nicely with other transformations like `grad` and `jit`.\n", "\n", "## A simple example\n", "\n", "Consider a simple example from data preprocessing: centering an array.\n", "\n", "Here's how we might write a vectorized version using NumPy:\n", "```python\n", "def center(array, axis=-1):\n", "  # array can have any number of dimensions\n", "  bias = np.mean(array, axis=axis)\n", "  debiased = array - np.expand_dims(bias, axis)\n", "  return bias, debiased\n", "```\n", "\n", "And here's how we could write a vectorized version using JAX gufuncs:\n", "```python\n", "@vectorize('(n)->(),(n)')\n", "def center(array):\n", "  # array is always a 1D vector\n", "  bias = jnp.mean(array)\n", "  debiased = array - bias\n", "  return bias, debiased\n", "```\n", "\n", "See the difference?\n", "- Instead of needing to think about broadcasting while writing the entire function, we can write the function assuming the input is always a vector.\n", "- We get the `axis` argument automatically, without needing to write it ourselves.\n", "- As a bonus, the decorator makes the function self-documenting: a reader immediately knows that it handles higher-dimensional input and output correctly.\n", "\n", "For more examples (and the implementation) see below." ] }, { "metadata": { "id": "k40qkuQdkqFg", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Implementation\n", "\n", "### License\n", "\n", "Copyright 2018 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", "\n", "https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License.\n" ] }, { "metadata": { "id": "4QrBJNYG5ECU", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Setup and imports" ] }, { "metadata": { "id": "2NXj3Dp5270W", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.13-cp36-none-linux_x86_64.whl\n", "!pip install --upgrade -q jax" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "p-rYDdqL1uZP", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from jax import grad, jit, vmap\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import re" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "tU2rwIOZmT0Q", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Copied from `numpy.lib.function_base`" ] }, { "metadata": { "id": "lBVIP3O2kkqY", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html\n", "_DIMENSION_NAME = r'\\w+'\n", "_CORE_DIMENSION_LIST = '(?:{0:}(?:,{0:})*)?'.format(_DIMENSION_NAME)\n", "_ARGUMENT = r'\\({}\\)'.format(_CORE_DIMENSION_LIST)\n", "_ARGUMENT_LIST = '{0:}(?:,{0:})*'.format(_ARGUMENT)\n", "_SIGNATURE = '^{0:}->{0:}$'.format(_ARGUMENT_LIST)\n", "\n", "\n", "def _parse_gufunc_signature(signature):\n", "  \"\"\"\n", "  Parse string signatures for a generalized universal function.\n", "\n", "  Arguments\n", "  ---------\n", "  signature : string\n", "      Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``\n", "      for ``np.matmul``.\n", "\n", "  Returns\n", "  -------\n", "  Tuple of input and output core dimensions parsed from the signature, each\n", "  of the form List[Tuple[str, ...]].\n", "  \"\"\"\n", "  if not re.match(_SIGNATURE, signature):\n", "    raise ValueError(\n", 
" 'not a valid gufunc signature: {}'.format(signature))\n", " return tuple([tuple(re.findall(_DIMENSION_NAME, arg))\n", " for arg in re.findall(_ARGUMENT, arg_list)]\n", " for arg_list in signature.split('->'))\n", "\n", "\n", "\n", "def _update_dim_sizes(dim_sizes, arg, core_dims):\n", " \"\"\"\n", " Incrementally check and update core dimension sizes for a single argument.\n", "\n", " Arguments\n", " ---------\n", " dim_sizes : Dict[str, int]\n", " Sizes of existing core dimensions. Will be updated in-place.\n", " arg : ndarray\n", " Argument to examine.\n", " core_dims : Tuple[str, ...]\n", " Core dimensions for this argument.\n", " \"\"\"\n", " if not core_dims:\n", " return\n", "\n", " num_core_dims = len(core_dims)\n", " if arg.ndim < num_core_dims:\n", " raise ValueError(\n", " '%d-dimensional argument does not have enough '\n", " 'dimensions for all core dimensions %r'\n", " % (arg.ndim, core_dims))\n", "\n", " core_shape = arg.shape[-num_core_dims:]\n", " for dim, size in zip(core_dims, core_shape):\n", " if dim in dim_sizes:\n", " if size != dim_sizes[dim]:\n", " raise ValueError(\n", " 'inconsistent size for core dimension %r: %r vs %r'\n", " % (dim, size, dim_sizes[dim]))\n", " else:\n", " dim_sizes[dim] = size\n", "\n", "\n", "def _parse_input_dimensions(args, input_core_dims):\n", " \"\"\"\n", " Parse broadcast and core dimensions for vectorize with a signature.\n", "\n", " Arguments\n", " ---------\n", " args : Tuple[ndarray, ...]\n", " Tuple of input arguments to examine.\n", " input_core_dims : List[Tuple[str, ...]]\n", " List of core dimensions corresponding to each input.\n", "\n", " Returns\n", " -------\n", " broadcast_shape : Tuple[int, ...]\n", " Common shape to broadcast all non-core dimensions to.\n", " dim_sizes : Dict[str, int]\n", " Common sizes for named core dimensions.\n", " \"\"\"\n", " broadcast_args = []\n", " dim_sizes = {}\n", " for arg, core_dims in zip(args, input_core_dims):\n", " _update_dim_sizes(dim_sizes, arg, 
core_dims)\n", " ndim = arg.ndim - len(core_dims)\n", " dummy_array = np.lib.stride_tricks.as_strided(0, arg.shape[:ndim])\n", " broadcast_args.append(dummy_array)\n", " broadcast_shape = np.lib.stride_tricks._broadcast_shape(*broadcast_args)\n", " return broadcast_shape, dim_sizes\n", "\n", "\n", "def _calculate_shapes(broadcast_shape, dim_sizes, list_of_core_dims):\n", " \"\"\"Helper for calculating broadcast shapes with core dimensions.\"\"\"\n", " return [broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)\n", " for core_dims in list_of_core_dims]\n", "\n", " \n", "# adapted from np.vectorize (again authored by shoyer@)\n", "def broadcast_with_core_dims(args, input_core_dims, output_core_dims):\n", " if len(args) != len(input_core_dims):\n", " raise TypeError('wrong number of positional arguments: '\n", " 'expected %r, got %r'\n", " % (len(input_core_dims), len(args)))\n", "\n", " broadcast_shape, dim_sizes = _parse_input_dimensions(\n", " args, input_core_dims)\n", " input_shapes = _calculate_shapes(broadcast_shape, dim_sizes,\n", " input_core_dims)\n", " args = [jnp.broadcast_to(arg, shape)\n", " for arg, shape in zip(args, input_shapes)]\n", " return args" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "aa_Gh3K_PQkY", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Handle the `axis` argument" ] }, { "metadata": { "id": "MLeFHhVoPT4h", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def verify_axis_is_supported(input_core_dims, output_core_dims):\n", " all_core_dims = set()\n", " for input_or_output_core_dims in [input_core_dims, output_core_dims]:\n", " for core_dims in input_or_output_core_dims:\n", " all_core_dims.update(core_dims)\n", " if len(core_dims) > 1:\n", " raise ValueError('only one gufuncs with one core dim support axis')\n", "\n", "\n", "def reorder_inputs(args, axis, input_core_dims):\n", " return tuple(jnp.moveaxis(arg, axis, -1) if core_dims else arg\n", " for arg, core_dims 
in zip(args, input_core_dims))\n", "\n", "\n", "def reorder_outputs(result, axis, output_core_dims):\n", "  if not isinstance(result, tuple):\n", "    result = (result,)\n", "  result = tuple(jnp.moveaxis(res, -1, axis) if core_dims else res\n", "                 for res, core_dims in zip(result, output_core_dims))\n", "  if len(result) == 1:\n", "    (result,) = result\n", "  return result" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "Uik9GA76lZjY", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Core implementation\n", "\n", "This is the only part that uses `vmap`." ] }, { "metadata": { "id": "z-FgQaW02_WN", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import functools\n", "\n", "\n", "def vectorize(signature):\n", "  \"\"\"Vectorize a function using JAX.\"\"\"\n", "  input_core_dims, output_core_dims = _parse_gufunc_signature(signature)\n", "\n", "  def decorator(func):\n", "    @functools.wraps(func)\n", "    def wrapper(*args, axis=None):\n", "\n", "      if axis is not None:\n", "        verify_axis_is_supported(input_core_dims, output_core_dims)\n", "        args = reorder_inputs(args, axis, input_core_dims)\n", "\n", "      broadcast_args = broadcast_with_core_dims(\n", "          args, input_core_dims, output_core_dims)\n", "      num_batch_dims = len(broadcast_args[0].shape) - len(input_core_dims[0])\n", "\n", "      vectorized_func = func\n", "      for _ in range(num_batch_dims):\n", "        vectorized_func = vmap(vectorized_func)\n", "      result = vectorized_func(*broadcast_args)\n", "\n", "      if axis is not None:\n", "        result = reorder_outputs(result, axis, output_core_dims)\n", "\n", "      return result\n", "    return wrapper\n", "  return decorator" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "EWDCFZiqmY9A", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Test cases\n" ] }, { "metadata": { "id": "W-jCowsgj_Tg", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### matmul" ] }, { "metadata": { "id": "gSJ7G_da4ArE", "colab_type": "code", 
"colab": {} }, "cell_type": "code", "source": [ "matmat = vectorize('(n,m),(m,k)->(n,k)')(jnp.dot)\n", "matvec = vectorize('(n,m),(m)->(n)')(jnp.dot)\n", "vecmat = vectorize('(m),(m,k)->(k)')(jnp.dot)\n", "vecvec = vectorize('(m),(m)->()')(jnp.dot)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "CI-vJzjMfPXS", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "assert matmat(np.zeros((2, 3)), np.zeros((3, 4))).shape == (2, 4)\n", "assert matmat(np.zeros((2, 3)), np.zeros((1, 3, 4))).shape == (1, 2, 4)\n", "assert matmat(np.zeros((5, 2, 3)), np.zeros((1, 3, 4))).shape == (5, 2, 4)\n", "assert matmat(np.zeros((6, 5, 2, 3)), np.zeros((3, 4))).shape == (6, 5, 2, 4)\n", "\n", "assert matvec(np.zeros((2, 3)), np.zeros((3,))).shape == (2,)\n", "assert matvec(np.zeros((2, 3)), np.zeros((1, 3))).shape == (1, 2)\n", "assert matvec(np.zeros((4, 2, 3)), np.zeros((1, 3))).shape == (4, 2)\n", "assert matvec(np.zeros((5, 4, 2, 3)), np.zeros((1, 3))).shape == (5, 4, 2)\n", "\n", "assert vecvec(np.zeros((3,)), np.zeros((3,))).shape == ()\n", "# these next two don't work yet\n", "# assert vecvec(np.zeros((2, 3)), np.zeros((3,))).shape == (2,)\n", "# assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2) " ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "u5xKzwoRkKuR", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### magnitude" ] }, { "metadata": { "id": "Rcbol3OHkKUQ", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "@vectorize('(n)->()')\n", "def magnitude(x):\n", " return jnp.dot(x, x)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "DBtX_QDwkMbI", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "assert magnitude(np.arange(3.0)).shape == ()\n", "# these next two don't work yet\n", "# assert magnitude(np.arange(6.0).reshape(2, 3)).shape == (2,)\n", "# assert magnitude(np.arange(6.0).reshape(1, 2, 3)).shape == (1, 2,)" ], "execution_count": 
0, "outputs": [] }, { "metadata": { "id": "LFlyTMg0kCm5", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### mean" ] }, { "metadata": { "id": "m5HrLVmehaHx", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "mean = vectorize('(n)->()')(jnp.mean)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "QBtnkLwnhhJY", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "assert mean(np.zeros((3,))).shape == ()\n", "assert mean(np.zeros((2, 3,))).shape == (2,)\n", "assert mean(np.zeros((2, 3,)), axis=0).shape == (3,)\n", "assert mean(np.zeros((1, 2, 3, 4))).shape == (1, 2, 3)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "hAMhRkK4kEzw", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### center" ] }, { "metadata": { "id": "bDsjXm7MitcX", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "@vectorize('(n)->(),(n)')\n", "def center(array):\n", " bias = jnp.mean(array)\n", " debiased = array - bias\n", " return bias, debiased" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "MSyxOrUPixDI", "colab_type": "code", "outputId": "65ff7d75-6658-402b-85f7-558924ebe934", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "b, a = center(jnp.arange(3))\n", "print(np.array(a), np.array(b))" ], "execution_count": 13, "outputs": [ { "output_type": "stream", "text": [ "[-1. 0. 1.] 
1.0\n" ], "name": "stdout" } ] }, { "metadata": { "id": "7UZGOS8FGT_D", "colab_type": "code", "outputId": "28dd7afe-2bbe-485a-ac6e-f8ca6bac9f37", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "X = jnp.arange(12).reshape(3, 4)\n", "X" ], "execution_count": 14, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[ 0, 1, 2, 3],\n", " [ 4, 5, 6, 7],\n", " [ 8, 9, 10, 11]])" ] }, "metadata": { "tags": [] }, "execution_count": 14 } ] }, { "metadata": { "id": "n2Fz91ptjM_7", "colab_type": "code", "outputId": "9ea07d50-7cbc-423c-bffb-92047dcca62c", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "b, a = center(X, axis=1)\n", "print(np.array(a), np.array(b))" ], "execution_count": 15, "outputs": [ { "output_type": "stream", "text": [ "[[-1.5 -0.5 0.5 1.5]\n", " [-1.5 -0.5 0.5 1.5]\n", " [-1.5 -0.5 0.5 1.5]] [1.5 5.5 9.5]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "G-AyCkAK4RKT", "colab_type": "code", "outputId": "dccce2b9-a7ec-4634-d45d-d440ddc1e1fe", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "b, a = center(X, axis=0)\n", "print(np.array(a), np.array(b))" ], "execution_count": 16, "outputs": [ { "output_type": "stream", "text": [ "[[-4. -4. -4. -4.]\n", " [ 0. 0. 0. 0.]\n", " [ 4. 4. 4. 4.]] [4. 5. 6. 
7.]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "_FhnjYMUjZgI", "colab_type": "code", "outputId": "96c71fea-b30e-4057-9634-22f4d62218e2", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "# NOTE: using the wrapped function directly silently gives the wrong result!\n", "b, a = center.__wrapped__(X)\n", "print(np.array(a), np.array(b))" ], "execution_count": 17, "outputs": [ { "output_type": "stream", "text": [ "[[-5.5 -4.5 -3.5 -2.5]\n", " [-1.5 -0.5 0.5 1.5]\n", " [ 2.5 3.5 4.5 5.5]] 5.5\n" ], "name": "stdout" } ] }, { "metadata": { "id": "1EeD7aFdENt8", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "" ], "execution_count": 0, "outputs": [] } ] }