{ "cells": [ { "cell_type": "markdown", "id": "57be61cb", "metadata": {}, "source": [ "# JAX Parallelization Tricks\n", "\n", "\n", "While differentiability is a great reason to use JAX, arguably\n", "its APIs for parallelization are an equally good reason.\n", "\n", "In standard..\n" ] },
{ "cell_type": "markdown", "id": "5b03fa84", "metadata": {}, "source": [ "## An Example\n", "\n", "Let's say you have a function computing something\n", "interesting for some inputs `x`.\n", "\n", "For example, we can use the textbook definition\n", "of a dense layer in a neural network:\n", "\n", "$$y = W x + b$$\n", "\n", "or, written with indices (summing over the repeated index $j$):\n", "\n", "$$y_{i} = W_{ij} x_{j} + b_{i}$$\n", "\n" ] },
{ "cell_type": "code", "execution_count": 90, "id": "b964aa1a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.1 1.2] -> [0.90000004 1.53 ]\n" ] } ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "def dense_layer(x):\n", "    # fixed example weights: W has shape (2, 2), b has shape (2,)\n", "    W = jnp.array([[0.2, 0.4], [0.9, 0.2]])\n", "    b = jnp.array([0.2, 0.3])\n", "    return W.dot(x) + b\n", "\n", "inputs = jnp.array([1.1, 1.2])\n", "out = dense_layer(inputs)\n", "print(f'{inputs} -> {out}')" ] },
{ "cell_type": "markdown", "id": "0a360820", "metadata": {}, "source": [ "In practice, however, we usually want to evaluate the layer on a mini-batch of inputs,\n", "so $x$ carries an additional batch index $b$:\n", "\n", "$$x = x_{bj}$$\n", "\n", "In index notation the dense layer thus becomes\n", "\n", "$$y_{bi} = W_{ij} x_{bj} + b_{i}$$\n", "\n", "but this is not compatible with our code `W.dot(x)` above: for a 2-D `x`, `dot` contracts\n", "the last axis of `W` with the *first* axis of `x`, which is now the batch index $b$ instead of $j$." ] },
{ "cell_type": "code", "execution_count": 28, "id": "6a837bf0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Incompatible shapes for dot: got (2, 2) and (3, 2).\n" ] } ], "source": [ "# a mini-batch of three inputs, shape (3, 2)\n", "inputs = jnp.array([\n", "    [1.1, 1.2],\n", "    [2.1, 2.2],\n", "    [3.1, 3.2]\n", "])\n", "\n", "try:\n", "    dense_layer(inputs)\n", "except TypeError as exc:\n", "    print(exc)" ] },
{ "cell_type": "markdown", "id": "4ace1791", "metadata": {}, "source": [ "## Workarounds\n", "\n", "### Einsums\n", "\n", "We can work around this by rewriting the dense layer with `jnp.einsum`,\n", "which lets us spell out the new batch index $b$ explicitly:" ] },
{ "cell_type": "code", "execution_count": 44, "id": "86ab5297", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.90000004, 1.53 ],\n", " [1.5000001 , 2.6299999 ],\n", " [2.1000001 , 3.7299998 ]], dtype=float32)" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def einsum_layer(x):\n", "    W = jnp.array([[0.2, 0.4], [0.9, 0.2]])\n", "    b = jnp.array([0.2, 0.3])\n", "    # 'ij,bj->bi' is y_bi = W_ij x_bj: contract over j, keep the batch index b\n", "    return jnp.einsum('ij,bj->bi', W, x) + b\n", "\n", "einsum_layer(inputs)" ] },
{ "cell_type": "markdown", "id": "f7e68267", "metadata": {}, "source": [ "### Transposition\n", "\n", "Another workaround you'll frequently see in textbooks is to swap the order of `W` and `x`\n", "(transposing `W`) such that the batch dimension is \"up front\" and\n", "won't be in the way:\n", "\n", "$$ y = xW^T + b $$\n", "\n", "which with indices looks like\n", "\n", "$$ y_{bi} = x_{bj}W^T_{ji} + b_{i} = W_{ij}x_{bj} + b_{i}$$\n", "\n", "\n", "This works too, but again at the cost of changing the code to accommodate batching.\n" ] },
{ "cell_type": "code", "execution_count": 45, "id": "0f974f37", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.90000004, 1.53 ],\n", " [1.5000001 , 2.6299999 ],\n", " [2.1000001 , 3.7299998 ]], dtype=float32)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def transpose_layer(x):\n", "    W = jnp.array([[0.2, 0.4], [0.9, 0.2]])\n", "    b = jnp.array([0.2, 0.3])\n", "    # x has shape (batch, j) and W.T has shape (j, i), so the result has shape (batch, i)\n", "    return x.dot(W.T) + b\n", "\n", "\n", "transpose_layer(inputs)" ] },
{ "cell_type": "markdown", "id": "7cfa1b16", "metadata": {}, "source": [ "## The JAX Way\n", "\n", "In JAX, batching could not be easier!\n", "\n", "There is no change to the code: you just apply the `jax.vmap` transformation\n", "to the function to receive a \"batched\" version of it. By default, `vmap` maps over\n", "the leading axis (axis 0) of the input and stacks the results along axis 0 of the output." ] },
{ "cell_type": "code", "execution_count": 48, "id": "f9371639", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.90000004, 1.53 ],\n", " [1.5000001 , 2.6299999 ],\n", " [2.1000001 , 3.7299998 ]], dtype=float32)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batched_dense = jax.vmap(dense_layer)\n", "batched_dense(inputs)" ] },
{ "cell_type": "markdown", "id": "320419ef", "metadata": {}, "source": [ "# Handling multiple inputs\n", "\n", "Sometimes a function has more than one input\n", "\n", "$$ f = f(x,y) $$\n", "\n", "and for such functions you need to tell `vmap` how each argument should be batched." ] },
{ "cell_type": "markdown", "id": "69a40d36", "metadata": {}, "source": [ "## Zipping\n", "\n", "When zipping, we want both arguments to be iterated \"in lock-step\".\n", "\n", "In standard Python it would look something like this:\n", "\n", "`out = [f(xi, yi) for xi, yi in zip(x, y)]`\n", "\n", "In JAX, we specify the batch dimension of each argument (here: 0 for both) via the `in_axes` argument:" ] },
{ "cell_type": "code", "execution_count": 71, "id": "03fdf36a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([5, 7, 9], dtype=int32)" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def func(x, y):\n", "    return x + y\n", "\n", "x_batched = jnp.array([1, 2, 3])\n", "y_batched = jnp.array([4, 5, 6])\n", "jax.vmap(func, in_axes=(0, 0))(x_batched, y_batched)" ] },
{ "cell_type": "markdown", "id": "a3635cc4", "metadata": {}, "source": [ "## Non-leading batch dimensions\n", "\n", "The batch dimension can also sit at a different position for each argument,\n", "for example:\n", "\n", "```\n", "x = [[1,2,3]]      # shape (1,3)\n", "y = [[4],[5],[6]]  # shape (3,1)\n", "```\n", "\n", "In this case our Python code would look like\n", "\n", "`out = [f(x[0][i], y[i][0]) for i in range(3)]`\n", "\n", "That is, for `x` the batch dimension is the second dimension of the array (index 1, counting from 0).\n", "\n", "For `y` it is the first dimension of the array (index 0).\n", "\n", "We can communicate this to JAX by specifying where the batch dimensions are\n", "for the inputs (`in_axes`) and where we want the batch dimension to be for the output (`out_axes`):" ] },
{ "cell_type": "code", "execution_count": 72, "id": "18c04558", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[5 7 9]]\n", "(1, 3)\n" ] } ], "source": [ "x_batched = jnp.array([[1, 2, 3]])\n", "y_batched = jnp.array([[4], [5], [6]])\n", "out1 = jax.vmap(func, in_axes=(1, 0), out_axes=1)(x_batched, y_batched)\n", "print(out1)\n", "print(out1.shape)  # batch dimension of size 3 at second position (idx 1)" ] },
{ "cell_type": "code", "execution_count": 73, "id": "889ab85c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[5]\n", " [7]\n", " [9]]\n", "(3, 1)\n" ] } ], "source": [ "out2 = jax.vmap(func, in_axes=(1, 0), out_axes=0)(x_batched, y_batched)\n", "print(out2)\n", "print(out2.shape)  # batch dimension of size 3 at first position (idx 0)" ] },
{ "cell_type": "markdown", "id": "4a1ca4da", "metadata": {}, "source": [ "## Un-batched arguments\n", "\n", "You can also have un-batched arguments that just \"go along for the ride\", e.g.\n", "\n", "```\n", "x = 3.0\n", "out = [f(x, y) for y in [3, 4, 5]]\n", "```\n", "\n", "In JAX, you express this by passing `None` as that argument's entry in `in_axes`:" ] },
{ "cell_type": "code", "execution_count": 77, "id": "c4a7a0c8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([ 9., 12., 15.], dtype=float32, weak_type=True)" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def func(x, y):\n", "    return x * y\n", "\n", "# x = 3.0 is shared by every call, only y is mapped over\n", "jax.vmap(func, in_axes=(None, 0))(3.0, jnp.array([3, 4, 5]))" ] }, { "cell_type": 
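"markdown", "id": "1c9e4b7a", "metadata": {}, "source": [ "The same `None` trick can be applied to the dense layer from the beginning. The following is a\n", "minimal sketch that assumes a hypothetical variant, `dense_layer_params`, which receives `W` and `b`\n", "as arguments instead of defining them inside the function. With `in_axes=(None, None, 0)` the\n", "weights are shared by every example while the inputs are mapped over their leading axis; as a\n", "sanity check, the result is compared with an explicit Python loop over the batch." ] },
{ "cell_type": "code", "execution_count": null, "id": "8f2d6e03", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# hypothetical variant of dense_layer that takes the weights as arguments\n", "def dense_layer_params(W, b, x):\n", "    return W.dot(x) + b\n", "\n", "W = jnp.array([[0.2, 0.4], [0.9, 0.2]])\n", "b = jnp.array([0.2, 0.3])\n", "inputs = jnp.array([[1.1, 1.2], [2.1, 2.2], [3.1, 3.2]])\n", "\n", "# W and b are not batched (None); x is mapped over axis 0 of inputs\n", "batched = jax.vmap(dense_layer_params, in_axes=(None, None, 0))\n", "out = batched(W, b, inputs)\n", "\n", "# reference: the equivalent Python loop over the batch\n", "loop = jnp.stack([dense_layer_params(W, b, xi) for xi in inputs])\n", "print(out.shape)                # (3, 2)\n", "print(jnp.allclose(out, loop))  # True" ] }, { "cell_type": 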
"markdown", "id": "2d241761", "metadata": {}, "source": [ "## Composition and Cartesian Products\n", "\n", "For a function $f(x,y)$ you may want to evaluate it on a \"grid\" of all combinations of `x` and `y`\n", "\n", "```\n", "out = [[f(xi, yj) for yj in y] for xi in x]\n", "```\n", "\n", "You can do this by applying `vmap` multiple times: the inner `vmap` maps over `y` with `x` held fixed,\n", "and the outer one maps over `x` with `y` held fixed, so that `out[i, j] = f(x[i], y[j])`." ] },
{ "cell_type": "code", "execution_count": null, "id": "d645f5a2", "metadata": {}, "outputs": [], "source": [ "x_range = jnp.linspace(-5, 5)\n", "y_range = jnp.linspace(-2, 2)\n", "\n", "def func(x, y):\n", "    return jnp.sin(x) * y\n", "\n", "# inner vmap: x fixed (None), y mapped (0); outer vmap: x mapped (0), y fixed (None)\n", "out = jax.vmap(jax.vmap(func, in_axes=(None, 0)), in_axes=(0, None))(x_range, y_range)\n", "\n", "# indexing='ij' produces grids with the same (x, y) layout as out\n", "grid = jnp.meshgrid(x_range, y_range, indexing='ij')\n", "plt.contourf(grid[0], grid[1], out, levels=31)\n", "plt.xlabel('x')\n", "plt.ylabel('y')\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.11" } }, "nbformat": 4, "nbformat_minor": 5 }