{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "l7Sh70ascFNs" }, "source": [ "# Solving the wave equation on cloud TPUs\n", "\n", "[_Stephan Hoyer_](https://twitter.com/shoyer)\n", "\n", "In this notebook, we solve the 2D [wave equation](https://en.wikipedia.org/wiki/Wave_equation):\n", "$$\n", "\\frac{\\partial^2 u}{\\partial t^2} = c^2 \\nabla^2 u\n", "$$\n", "\n", "We use a simple [finite difference](https://en.wikipedia.org/wiki/Finite_difference_method) formulation with [leapfrog time integration](https://en.wikipedia.org/wiki/Leapfrog_integration).\n", "\n", "Note: It is natural to express finite difference methods as convolutions, but here we intentionally avoid convolutions in favor of array indexing/arithmetic. This is because \"batch\" and \"feature\" dimensions in TPU convolutions are padded to multiples of either 8 or 128, but in our case both of these dimensions are effectively of size 1.\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xAd8PvW6ceyk" }, "source": [ "## Set up the required environment" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "n6L8zIaAIrqj" }, "outputs": [], "source": [ "# Grab other packages for this demo.\n", "!pip install -U -q Pillow moviepy proglog scikit-image\n", "\n", "# Make sure the Colab Runtime is set to Accelerator: TPU.\n", "import requests\n", "import os\n", "if 'TPU_DRIVER_MODE' not in globals():\n", " url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'\n", " resp = requests.post(url)\n", " TPU_DRIVER_MODE = 1\n", "\n", "# The following is required to use TPU Driver as JAX's backend.\n", "from jax.config import config\n", "config.FLAGS.jax_xla_backend = \"tpu_driver\"\n", "config.FLAGS.jax_backend_target = \"grpc://\" + os.environ['COLAB_TPU_ADDR']\n", "print(config.FLAGS.jax_backend_target)" ] }, { "cell_type": "markdown", "metadata": { 
"colab_type": "text", "id": "UHkzcAMhcpm_" }, "source": [ "## Simulation code" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "4NJQ8Q5M99wy" }, "outputs": [], "source": [ "from functools import partial\n", "import jax\n", "from jax import jit, pmap\n", "from jax import lax\n", "from jax import tree_util\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import skimage.filters\n", "import proglog\n", "from moviepy.editor import ImageSequenceClip\n", "\n", "device_count = jax.device_count()\n", "\n", "# Spatial partitioning via halo exchange\n", "\n", "def send_right(x, axis_name):\n", " # Note: if some devices are omitted from the permutation, lax.ppermute\n", " # provides zeros instead. This gives us an easy way to apply Dirichlet\n", " # boundary conditions.\n", " right_perm = [(i, (i + 1) % device_count) for i in range(device_count - 1)]\n", " return lax.ppermute(x, perm=right_perm, axis_name=axis_name)\n", "\n", "def send_left(x, axis_name):\n", " left_perm = [((i + 1) % device_count, i) for i in range(device_count - 1)]\n", " return lax.ppermute(x, perm=left_perm, axis_name=axis_name)\n", "\n", "def axis_slice(ndim, index, axis):\n", " slices = [slice(None)] * ndim\n", " slices[axis] = index\n", " return tuple(slices)\n", "\n", "def slice_along_axis(array, index, axis):\n", " return array[axis_slice(array.ndim, index, axis)]\n", "\n", "def tree_vectorize(func):\n", " def wrapper(x, *args, **kwargs):\n", " return tree_util.tree_map(lambda x: func(x, *args, **kwargs), x)\n", " return wrapper\n", "\n", "@tree_vectorize\n", "def halo_exchange_padding(array, padding=1, axis=0, axis_name='x'):\n", " if not padding > 0:\n", " raise ValueError(f'invalid padding: {padding}')\n", " array = jnp.array(array)\n", " if array.ndim == 0:\n", " return array\n", " left = slice_along_axis(array, slice(None, padding), axis)\n", " right = slice_along_axis(array, slice(-padding, 
None), axis)\n", " right, left = send_left(left, axis_name), send_right(right, axis_name)\n", " return jnp.concatenate([left, array, right], axis)\n", "\n", "@tree_vectorize\n", "def halo_exchange_inplace(array, padding=1, axis=0, axis_name='x'):\n", " left = slice_along_axis(array, slice(padding, 2*padding), axis)\n", " right = slice_along_axis(array, slice(-2*padding, -padding), axis)\n", " right, left = send_left(left, axis_name), send_right(right, axis_name)\n", " array = jax.ops.index_update(\n", " array, axis_slice(array.ndim, slice(None, padding), axis), left)\n", " array = jax.ops.index_update(\n", " array, axis_slice(array.ndim, slice(-padding, None), axis), right)\n", " return array\n", "\n", "# Reshaping inputs/outputs for pmap\n", "\n", "def split_with_reshape(array, num_splits, *, split_axis=0, tile_id_axis=None):\n", " if tile_id_axis is None:\n", " tile_id_axis = split_axis\n", " tile_size, remainder = divmod(array.shape[split_axis], num_splits)\n", " if remainder:\n", " raise ValueError('num_splits must equally divide the dimension size')\n", " new_shape = list(array.shape)\n", " new_shape[split_axis] = tile_size\n", " new_shape.insert(split_axis, num_splits)\n", " return jnp.moveaxis(jnp.reshape(array, new_shape), split_axis, tile_id_axis)\n", "\n", "def stack_with_reshape(array, *, split_axis=0, tile_id_axis=None):\n", " if tile_id_axis is None:\n", " tile_id_axis = split_axis\n", " array = jnp.moveaxis(array, tile_id_axis, split_axis)\n", " new_shape = array.shape[:split_axis] + (-1,) + array.shape[split_axis+2:]\n", " return jnp.reshape(array, new_shape)\n", "\n", "def shard(func):\n", " def wrapper(state):\n", " sharded_state = tree_util.tree_map(\n", " lambda x: split_with_reshape(x, device_count), state)\n", " sharded_result = func(sharded_state)\n", " result = tree_util.tree_map(stack_with_reshape, sharded_result)\n", " return result\n", " return wrapper\n", "\n", "# Physics\n", "\n", "def shift(array, offset, axis):\n", " index = 
slice(offset, None) if offset >= 0 else slice(None, offset)\n", " sliced = slice_along_axis(array, index, axis)\n", " padding = [(0, 0)] * array.ndim\n", " padding[axis] = (-min(offset, 0), max(offset, 0))\n", " return jnp.pad(sliced, padding, mode='constant', constant_values=0)\n", "\n", "def laplacian(array, step=1):\n", " left = shift(array, +1, axis=0)\n", " right = shift(array, -1, axis=0)\n", " up = shift(array, +1, axis=1)\n", " down = shift(array, -1, axis=1)\n", " convolved = (left + right + up + down - 4 * array)\n", " if step != 1:\n", " convolved *= (1 / step ** 2)\n", " return convolved\n", "\n", "def scalar_wave_equation(u, c=1, dx=1):\n", " return c ** 2 * laplacian(u, dx)\n", "\n", "@jax.jit\n", "def leapfrog_step(state, dt=0.5, c=1):\n", " # https://en.wikipedia.org/wiki/Leapfrog_integration\n", " u, u_t = state\n", " u_tt = scalar_wave_equation(u, c)\n", " u_t = u_t + u_tt * dt\n", " u = u + u_t * dt\n", " return (u, u_t)\n", "\n", "# Time stepping\n", "\n", "def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):\n", " return lax.fori_loop(0, count, lambda i, s: leapfrog_step(s, dt, c), state)\n", "\n", "def multi_step_pmap(state, count, dt=1/jnp.sqrt(2), c=1, exchange_interval=1,\n", " save_interval=1):\n", "\n", " def exchange_and_multi_step(state_padded):\n", " c_padded = halo_exchange_padding(c, exchange_interval)\n", " evolved = multi_step(state_padded, exchange_interval, dt, c_padded)\n", " return halo_exchange_inplace(evolved, exchange_interval)\n", "\n", " @shard\n", " @partial(jax.pmap, axis_name='x')\n", " def simulate_until_output(state):\n", " stop = save_interval // exchange_interval\n", " state_padded = halo_exchange_padding(state, exchange_interval)\n", " advanced = lax.fori_loop(\n", " 0, stop, lambda i, s: exchange_and_multi_step(s), state_padded)\n", " xi = exchange_interval\n", " return tree_util.tree_map(lambda array: array[xi:-xi, ...], advanced)\n", "\n", " results = [state]\n", " for _ in range(count // save_interval):\n", " 
state = simulate_until_output(state)\n", " tree_util.tree_map(lambda x: x.copy_to_host_async(), state)\n", " results.append(state)\n", " results = jax.device_get(results)\n", " return tree_util.tree_multimap(lambda *xs: np.stack([np.array(x) for x in xs]), *results)\n", "\n", "multi_step_jit = jax.jit(multi_step)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_kMyXbkEeCu3" }, "source": [ "## Initial conditions" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "usWkf2UgAYc5" }, "outputs": [], "source": [ "x = jnp.linspace(0, 8, num=8*1024, endpoint=False)\n", "y = jnp.linspace(0, 1, num=1*1024, endpoint=False)\n", "x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')\n", "\n", "# NOTE: smooth initial conditions are important, so we aren't exciting\n", "# arbitrarily high frequencies (that cannot be resolved)\n", "u = skimage.filters.gaussian(\n", " ((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) < 0.1 ** 2,\n", " sigma=1)\n", "\n", "# u = jnp.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)\n", "\n", "# u = skimage.filters.gaussian(\n", "# (x_mesh > 1/3) & (x_mesh < 1/2) & (y_mesh > 1/3) & (y_mesh < 1/2),\n", "# sigma=5)\n", "\n", "v = jnp.zeros_like(u)\n", "c = 1 # could also use a 2D array matching the mesh shape" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "jWtIvDTfCx12" }, "outputs": [], "source": [ "u.shape" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RVdPIRKmeNX3" }, "source": [ "## Test scaling from 1 to 8 chips" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "SPG6jvYSCaKQ" }, "outputs": [], "source": [ "%%time\n", "# single TPU chip\n", "u_final, _ = multi_step_jit((u, v), count=2**13, c=c, dt=0.5)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "MpsDzNyC6OI0" }, 
"outputs": [], "source": [ "%%time\n", "# 8x TPU chips, 4x more steps in roughly half the time!\n", "u_final, _ = multi_step_pmap(\n", " (u, v), count=2**15, c=c, dt=0.5, exchange_interval=4, save_interval=2**15)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "fPCwJ51FeBLR" }, "outputs": [], "source": [ "18.3 / (10.3 / 4) # near linear scaling (8x would be perfect)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jYMn69z8eQ7O" }, "source": [ "## Save a bunch of outputs for a movie" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Uns4X34EPYgF" }, "outputs": [], "source": [ "%%time\n", "# save more outputs for a movie -- this is slow!\n", "u_final, _ = multi_step_pmap(\n", " (u, v), count=2**15, c=c, dt=0.2, exchange_interval=4, save_interval=2**10)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "Hsjuk22Nbe9Z" }, "outputs": [], "source": [ "u_final.shape" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "LXqmz-CCdQjt" }, "outputs": [], "source": [ "u_final.nbytes / 1e9" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "cPvXrPUCPPtt" }, "outputs": [], "source": [ "plt.figure(figsize=(18, 6))\n", "plt.axis('off')\n", "plt.imshow(u_final[-1].T, cmap='RdBu');" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "uqxGmttmgSsN" }, "outputs": [], "source": [ "fig, axes = plt.subplots(9, 1, figsize=(14, 14))\n", "[ax.axis('off') for ax in axes]\n", "axes[0].imshow(u_final[0].T, cmap='RdBu', aspect='equal', vmin=-1, vmax=1)\n", "for i in range(8):\n", " axes[i+1].imshow(u_final[4*i+1].T / abs(u_final[4*i+1]).max(), cmap='RdBu', aspect='equal', vmin=-1, vmax=1)" ] }, { "cell_type": "code", "execution_count": 0, 
"metadata": { "colab": {}, "colab_type": "code", "id": "hNZrt590p6zr" }, "outputs": [], "source": [ "import matplotlib.cm\n", "import matplotlib.colors\n", "from PIL import Image\n", "\n", "def make_images(data, cmap='RdBu', vmax=None):\n", " images = []\n", " for frame in data:\n", " if vmax is None:\n", " this_vmax = np.max(abs(frame))\n", " else:\n", " this_vmax = vmax\n", " norm = matplotlib.colors.Normalize(vmin=-this_vmax, vmax=this_vmax)\n", " mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)\n", " rgba = mappable.to_rgba(frame, bytes=True)\n", " image = Image.fromarray(rgba, mode='RGBA')\n", " images.append(image)\n", " return images\n", "\n", "def save_movie(images, path, duration=100, loop=0, **kwargs):\n", " images[0].save(path, save_all=True, append_images=images[1:],\n", " duration=duration, loop=loop, **kwargs)\n", "\n", "images = make_images(u_final[::, ::8, ::8].transpose(0, 2, 1))" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "ZWEZWehFjboa" }, "outputs": [], "source": [ "# Show Movie\n", "proglog.default_bar_logger = partial(proglog.default_bar_logger, None)\n", "ImageSequenceClip([np.array(im) for im in images], fps=25).ipython_display()" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "whG0OOepezuX" }, "outputs": [], "source": [ "# Save GIF.\n", "save_movie(images,'wave_movie.gif', duration=[2000]+[200]*(len(images)-2)+[2000])\n", "# The movie sometimes takes a second before showing up in the file system.\n", "import time; time.sleep(1)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "E1V7aS63vRAT" }, "outputs": [], "source": [ "# Download animation.\n", "try:\n", " from google.colab import files\n", "except ImportError:\n", " pass\n", "else:\n", " files.download('wave_movie.gif')" ] } ], "metadata": { "accelerator": "TPU", "colab": { "collapsed_sections": [], 
"name": "JAX TPU wave equation", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 1 }