# Solving the wave equation on cloud TPUs

[_Stephan Hoyer_](https://twitter.com/shoyer)

In this notebook, we solve the 2D [wave equation](https://en.wikipedia.org/wiki/Wave_equation):
$$
\frac{\partial^2 u}{\partial t^2} = c^2 \nabla^2 u
$$

We use a simple [finite difference](https://en.wikipedia.org/wiki/Finite_difference_method) formulation with [Leapfrog time integration](https://en.wikipedia.org/wiki/Leapfrog_integration).

Note: It is natural to express finite difference methods as convolutions, but here we intentionally avoid convolutions in favor of array indexing/arithmetic. This is because the "batch" and "feature" dimensions in TPU convolutions are padded to multiples of 8 or 128, but in our case both of these dimensions effectively have size 1.


## Set up the required environment

In [0]:
# Grab other packages for this demo.
!pip install -U -q Pillow moviepy proglog scikit-image

# Make sure the Colab Runtime is set to Accelerator: TPU.
# COLAB_TPU_ADDR is presumably only set on a TPU runtime; os.environ[...]
# raises KeyError when it is missing.
import requests
import os
# Ask the TPU host (port 8475 of COLAB_TPU_ADDR) for the pinned tpu_driver
# build. TPU_DRIVER_MODE is a marker global, so re-running this cell in the
# same kernel session does not send the request again.
if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    # NOTE(review): `resp` is never checked, so a failed request goes
    # unnoticed here.
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
# NOTE(review): `from jax.config import config` and these FLAGS-based backend
# settings belong to old JAX releases; confirm the installed JAX supports them.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

## Simulation code

In [0]:
from functools import partial
from functools import wraps
import jax
from jax import jit, pmap
from jax import lax
from jax import tree_util
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import skimage.filters
import proglog
from moviepy.editor import ImageSequenceClip

# Number of devices visible to JAX. It is used as the number of spatial tiles
# (shard) and to build the halo-exchange permutations (send_left/send_right).
device_count = jax.device_count()

# Spatial partitioning via halo exchange

def send_right(x, axis_name):
    """Pass ``x`` one device to the right along the mapped axis ``axis_name``.

    Device ``i`` sends its value to device ``i + 1`` for
    ``i = 0 .. device_count - 2``. The last device sends nothing, and the
    first device has no sender.

    Note: if some devices are omitted from the permutation, lax.ppermute
    provides zeros instead. Device 0 therefore receives zeros, which gives an
    easy way to apply Dirichlet boundary conditions.

    Assumes the ``pmap`` axis spans all ``device_count`` devices.

    Args:
      x: per-device value to transfer.
      axis_name: name of the ``pmap`` axis to communicate over.

    Returns:
      The value received from the left neighbor (zeros on device 0).
    """
    # (source, destination) pairs. There is deliberately no wrap-around pair
    # from the last device back to device 0, so the domain is not periodic.
    right_perm = [(i, i + 1) for i in range(device_count - 1)]
    return lax.ppermute(x, perm=right_perm, axis_name=axis_name)

def send_left(x, axis_name):
    """Pass ``x`` one device to the left along the mapped axis ``axis_name``.

    Device ``i + 1`` sends its value to device ``i``. The last device has no
    sender, so (as ``lax.ppermute`` fills omitted devices with zeros) it
    receives zeros, giving a Dirichlet boundary on that side.
    """
    sources = range(1, device_count)
    destinations = range(device_count - 1)
    return lax.ppermute(x, perm=list(zip(sources, destinations)),
                        axis_name=axis_name)

def axis_slice(ndim, index, axis):
    """Build an indexing tuple that applies ``index`` on ``axis`` only.

    All other of the ``ndim`` dimensions get a full ``slice(None)``. Negative
    ``axis`` values count from the end. An out-of-range axis raises
    IndexError.
    """
    # Indexing a range normalizes negative axes and raises IndexError exactly
    # where list assignment would.
    target = range(ndim)[axis]
    return tuple(index if dim == target else slice(None) for dim in range(ndim))

def slice_along_axis(array, index, axis):
    """Return ``array`` indexed by ``index`` along ``axis`` (all else kept)."""
    selector = axis_slice(array.ndim, index, axis)
    return array[selector]

def tree_vectorize(func):
    """Lift ``func`` so that it is applied to every leaf of a pytree.

    The returned wrapper maps ``func(leaf, *args, **kwargs)`` over the leaves
    of its first argument with ``tree_util.tree_map`` and returns a pytree of
    the same structure. Extra positional and keyword arguments are forwarded
    unchanged to every call. ``functools.wraps`` keeps ``func``'s name and
    docstring on the wrapper.
    """
    @wraps(func)
    def wrapper(x, *args, **kwargs):
        return tree_util.tree_map(lambda leaf: func(leaf, *args, **kwargs), x)
    return wrapper

@tree_vectorize
def halo_exchange_padding(array, padding=1, axis=0, axis_name='x'):
    """Grow ``array`` by ``padding`` halo cells on both sides of ``axis``.

    Each device sends its first ``padding`` cells to its left neighbor and its
    last ``padding`` cells to its right neighbor. It then places its own data
    between the halos it received. Devices at the two ends of the device chain
    receive zeros (see ``send_left``/``send_right``). 0-d inputs are returned
    unchanged, which lets a scalar wave speed pass through. The function is
    applied leaf-wise to pytrees via ``tree_vectorize``.

    Raises:
      ValueError: if ``padding`` is not positive.
    """
    if not padding > 0:
        raise ValueError(f'invalid padding: {padding}')
    array = jnp.array(array)
    if array.ndim == 0:
        return array
    head = slice_along_axis(array, slice(None, padding), axis)
    tail = slice_along_axis(array, slice(-padding, None), axis)
    # Our head becomes the left neighbor's right halo, and our tail becomes
    # the right neighbor's left halo.
    from_right = send_left(head, axis_name)
    from_left = send_right(tail, axis_name)
    return jnp.concatenate([from_left, array, from_right], axis)

@tree_vectorize
def halo_exchange_inplace(array, padding=1, axis=0, axis_name='x'):
    """Refresh the existing ``padding``-wide halo cells of ``array``.

    This is the counterpart of ``halo_exchange_padding`` for an array that
    already carries halos. The first and last ``padding`` cells along ``axis``
    are replaced by the outermost interior cells of the neighboring devices.
    The ends of the device chain get zeros instead. The shape is unchanged.
    The function is applied leaf-wise to pytrees via ``tree_vectorize``.

    Assumes ``array.shape[axis] >= 2 * padding``.

    Raises:
      ValueError: if ``padding`` is not positive.
    """
    # Same validation as halo_exchange_padding. With padding == 0 the slice
    # ``-padding:`` would select the whole axis rather than nothing.
    if not padding > 0:
        raise ValueError(f'invalid padding: {padding}')
    # Interior cells next to each halo: exactly what the neighbors need.
    inner_left = slice_along_axis(array, slice(padding, 2 * padding), axis)
    inner_right = slice_along_axis(array, slice(-2 * padding, -padding), axis)
    from_right = send_left(inner_left, axis_name)
    from_left = send_right(inner_right, axis_name)
    interior = slice_along_axis(array, slice(padding, -padding), axis)
    # Rebuild with concatenate, as halo_exchange_padding does. The previous
    # jax.ops.index_update has been removed from newer JAX releases, while
    # concatenate works on every version.
    return jnp.concatenate([from_left, interior, from_right], axis)

# Reshaping inputs/outputs for pmap

def split_with_reshape(array, num_splits, *, split_axis=0, tile_id_axis=None):
    """Split ``array`` into ``num_splits`` equal tiles along ``split_axis``.

    The tiles are stacked along a new "tile id" axis placed at
    ``tile_id_axis`` in the result (default: ``split_axis``). For example, an
    array of shape ``(8, 3)`` split 4 ways along axis 0 becomes ``(4, 2, 3)``.
    With the defaults the tile id is axis 0, which is the axis ``pmap`` maps
    over in ``shard``.

    Args:
      array: array to split.
      num_splits: number of tiles. It must evenly divide
        ``array.shape[split_axis]``.
      split_axis: axis of ``array`` to split. Negative values count from the
        end of ``array``'s dimensions.
      tile_id_axis: position of the tile-id axis in the (ndim + 1)-d result.

    Returns:
      An array with one more dimension than ``array``.

    Raises:
      ValueError: if ``num_splits`` does not evenly divide the axis.
    """
    if split_axis < 0:
        # Normalize before reshaping. The reshape adds an axis, so a negative
        # index would otherwise point at the tile axis, not the tile-id axis.
        split_axis += array.ndim
    if tile_id_axis is None:
        tile_id_axis = split_axis
    tile_size, remainder = divmod(array.shape[split_axis], num_splits)
    if remainder:
        raise ValueError('num_splits must equally divide the dimension size')
    new_shape = list(array.shape)
    # Replace the split axis by a (tile id, tile) pair of axes.
    new_shape[split_axis:split_axis + 1] = [num_splits, tile_size]
    return jnp.moveaxis(jnp.reshape(array, new_shape), split_axis, tile_id_axis)

def stack_with_reshape(array, *, split_axis=0, tile_id_axis=None):
    """Inverse of ``split_with_reshape``: merge tiles back into one array.

    Moves the tile-id axis (at ``tile_id_axis``, default ``split_axis``) to
    ``split_axis``, then merges it with the tile axis that follows it. For
    example, shape ``(4, 2, 3)`` becomes ``(8, 3)``.

    Args:
      array: tiled array, one dimension larger than the result.
      split_axis: axis of the *result* along which the tiles are joined.
        Negative values count from the end of the result's dimensions, so the
        same value can be passed to ``split_with_reshape`` and here.
      tile_id_axis: position of the tile-id axis in ``array``.

    Returns:
      An array with one fewer dimension than ``array``.
    """
    if split_axis < 0:
        # The result has array.ndim - 1 dimensions; normalize against that.
        split_axis += array.ndim - 1
    if tile_id_axis is None:
        tile_id_axis = split_axis
    array = jnp.moveaxis(array, tile_id_axis, split_axis)
    # Collapse the (tile id, tile) pair of axes into a single axis.
    new_shape = array.shape[:split_axis] + (-1,) + array.shape[split_axis + 2:]
    return jnp.reshape(array, new_shape)

def shard(func):
    """Run ``func`` on its pytree argument split into ``device_count`` tiles.

    Every input leaf is split along axis 0 with ``split_with_reshape``, so
    ``func`` (a ``pmap``-ed function in this notebook) sees a leading device
    axis. Every leaf of ``func``'s output is merged back with
    ``stack_with_reshape``.
    """
    def split_leaf(leaf):
        return split_with_reshape(leaf, device_count)

    def wrapper(state):
        tiled_input = tree_util.tree_map(split_leaf, state)
        tiled_output = func(tiled_input)
        return tree_util.tree_map(stack_with_reshape, tiled_output)

    return wrapper

# Physics

def shift(array, offset, axis):
    """Shift ``array`` by ``offset`` cells along ``axis``, filling with zeros.

    Along ``axis``, ``out[i] = array[i + offset]`` where that index is in
    range, and 0 otherwise. A positive ``offset`` therefore pulls values from
    higher indices. The output always has the same shape as ``array``.

    Args:
      array: input array.
      offset: integer shift of any magnitude.
      axis: axis to shift along.
    """
    size = array.shape[axis]
    # Clamp so that |offset| >= size yields all zeros of the right shape.
    # Without the clamp, the slice is empty but the pad still adds |offset|
    # cells, producing a longer array.
    offset = max(-size, min(offset, size))
    index = slice(offset, None) if offset >= 0 else slice(None, offset)
    sliced = slice_along_axis(array, index, axis)
    padding = [(0, 0)] * array.ndim
    padding[axis] = (-min(offset, 0), max(offset, 0))
    return jnp.pad(sliced, padding, mode='constant', constant_values=0)

def laplacian(array, step=1):
    """Five-point finite-difference Laplacian over the first two axes.

    Values outside the array are treated as zero (via ``shift``). ``step`` is
    the grid spacing along both axes:
    (sum of the 4 neighbors - 4 * center) / step**2.
    """
    x_plus = shift(array, +1, axis=0)
    x_minus = shift(array, -1, axis=0)
    y_plus = shift(array, +1, axis=1)
    y_minus = shift(array, -1, axis=1)
    stencil = x_plus + x_minus + y_plus + y_minus - 4 * array
    if step == 1:
        return stencil
    return stencil * (1 / step ** 2)

def scalar_wave_equation(u, c=1, dx=1):
    """Return u_tt = c**2 * laplacian(u) on a grid with spacing ``dx``."""
    wave_speed_squared = c ** 2
    return wave_speed_squared * laplacian(u, dx)

@jax.jit
def leapfrog_step(state, dt=0.5, c=1):
    """Advance ``state = (u, u_t)`` by one time step of size ``dt``.

    The velocity is updated first from the current acceleration. The new
    velocity is then used to update ``u`` ("kick, then drift"). This is
    equivalent to leapfrog if ``u_t`` is read as living at half steps. The
    grid spacing is 1, because ``scalar_wave_equation`` is called without
    ``dx``. See https://en.wikipedia.org/wiki/Leapfrog_integration

    Returns:
      The updated ``(u, u_t)`` tuple.
    """
    u, velocity = state
    acceleration = scalar_wave_equation(u, c)
    velocity = velocity + acceleration * dt
    return (u + velocity * dt, velocity)

# Time stepping

def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):
    """Apply ``leapfrog_step`` ``count`` times to ``state = (u, u_t)``.

    The steps run as a single ``lax.fori_loop``. The default ``dt`` of
    1/sqrt(2) is presumably the 2-D stability (CFL) limit for unit grid
    spacing and ``c = 1``.
    """
    def body(_, current):
        return leapfrog_step(current, dt, c)
    return lax.fori_loop(0, count, body, state)

def multi_step_pmap(state, count, dt=1/jnp.sqrt(2), c=1, exchange_interval=1,
                    save_interval=1):
    """Run ``count`` leapfrog steps with the domain partitioned over devices.

    ``state = (u, u_t)`` is split into ``device_count`` tiles along axis 0
    (see ``shard``). Each tile is padded with ``exchange_interval`` halo cells
    per side. It is then advanced ``exchange_interval`` steps between halo
    exchanges. The stencil reaches one cell per step, so after k steps only
    the outer k halo cells are stale and the interior stays correct. Every
    ``save_interval`` steps the tiles are reassembled and the state recorded.

    Args:
      state: tuple (u, u_t) of full-domain arrays. Axis 0 must be divisible
        by ``device_count``.
      count: total number of time steps. It must be a multiple of
        ``save_interval``.
      dt: time step.
      c: wave speed. It is closed over rather than sharded, so in practice it
        must be a scalar; a full-mesh array would not match the per-device
        tiles.
      exchange_interval: steps between halo exchanges, and also the halo
        width.
      save_interval: steps between saved outputs. It must be a multiple of
        ``exchange_interval``.

    Returns:
      A pytree like ``state`` whose leaves are NumPy arrays with a new leading
      axis of length ``count // save_interval + 1`` (initial state first).

    Raises:
      ValueError: if an interval is not positive, or the step counts do not
        divide evenly (otherwise the remainder steps would be dropped
        silently).
    """
    if exchange_interval <= 0 or save_interval <= 0:
        raise ValueError('exchange_interval and save_interval must be positive')
    if save_interval % exchange_interval:
        raise ValueError('save_interval must be a multiple of exchange_interval')
    if count % save_interval:
        raise ValueError('count must be a multiple of save_interval')

    # Older JAX releases need tree_multimap to map over several trees. Newer
    # releases removed it, and their tree_map accepts several trees.
    tree_multimap = getattr(tree_util, 'tree_multimap', tree_util.tree_map)
    xi = exchange_interval

    def exchange_and_multi_step(state_padded):
        # A scalar c passes through halo_exchange_padding unchanged.
        c_padded = halo_exchange_padding(c, xi)
        evolved = multi_step(state_padded, xi, dt, c_padded)
        return halo_exchange_inplace(evolved, xi)

    @shard
    @partial(jax.pmap, axis_name='x')
    def simulate_until_output(state):
        stop = save_interval // xi
        state_padded = halo_exchange_padding(state, xi)
        advanced = lax.fori_loop(
            0, stop, lambda i, s: exchange_and_multi_step(s), state_padded)
        # Drop the halos again before the tiles are stacked back together.
        return tree_util.tree_map(lambda array: array[xi:-xi, ...], advanced)

    results = [state]
    for _ in range(count // save_interval):
        state = simulate_until_output(state)
        # Start device-to-host copies early; all results are fetched below.
        tree_util.tree_map(lambda x: x.copy_to_host_async(), state)
        results.append(state)
    results = jax.device_get(results)
    return tree_multimap(lambda *xs: np.stack([np.array(x) for x in xs]), *results)

# Whole-domain counterpart of multi_step_pmap: multi_step compiled with jit.
multi_step_jit = jit(multi_step)

## Initial conditions

In [0]:
# Domain [0, 8) x [0, 1) on an 8192 x 1024 grid. With indexing='ij', axis 0
# of the mesh arrays is x and axis 1 is y.
# NOTE: leapfrog_step calls scalar_wave_equation with its default dx=1, so dt
# and c are in grid units, not in the units of x and y.
x = jnp.linspace(0, 8, num=8*1024, endpoint=False)
y = jnp.linspace(0, 1, num=1*1024, endpoint=False)
x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')

# NOTE: smooth initial conditions are important, so we aren't exciting
# arbitrarily high frequencies (that cannot be resolved)
# Initial displacement: a disk of radius 0.1 centred at (1/3, 1/4), blurred
# with a Gaussian (skimage's sigma is measured in grid cells).
u = skimage.filters.gaussian(
    ((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) < 0.1 ** 2,
    sigma=1)

# Alternative initial condition: a Gaussian bump.
# u = jnp.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)

# Alternative initial condition: a smoothed square.
# u = skimage.filters.gaussian(
#     (x_mesh > 1/3) & (x_mesh < 1/2) & (y_mesh > 1/3) & (y_mesh < 1/2),
#     sigma=5)

# Start at rest (zero initial velocity u_t).
v = jnp.zeros_like(u)
# A 2-D array matching the mesh shape also works with multi_step_jit.
# multi_step_pmap does not shard c, however, so keep it scalar there.
c = 1

In [0]:
# Expect (8192, 1024): x along axis 0, y along axis 1 (indexing='ij').
u.shape

## Test scaling from 1 to 8 chips

In [0]:
%%time
# single TPU chip
u_final, _ = multi_step_jit((u, v), count=2**13, c=c, dt=0.5)

In [0]:
%%time
# 8x TPU chips, 4x more steps in roughly half the time!
# A halo exchange every 4 steps. save_interval == count, so u_final holds just
# two frames: the initial state and the final state.
u_final, _ = multi_step_pmap(
    (u, v), count=2**15, c=c, dt=0.5, exchange_interval=4, save_interval=2**15)

In [0]:
# Per-step speedup of the multi-chip run over the single-chip run.
# 18.3 and 10.3 are presumably wall times in seconds recorded from the two
# %%time cells above. The second run takes 4x as many steps (2**15 vs 2**13),
# hence the / 4.
18.3 / (10.3 / 4) # near linear scaling (8x would be perfect)

## Save a bunch of outputs for a movie

In [0]:
%%time
# save more outputs for a movie -- this is slow!
# 2**15 / 2**10 = 32 saved frames plus the initial state. Note the smaller
# dt=0.2 here.
u_final, _ = multi_step_pmap(
    (u, v), count=2**15, c=c, dt=0.2, exchange_interval=4, save_interval=2**10)

In [0]:
# (number of saved frames, nx, ny).
u_final.shape

In [0]:
# Size of all saved frames in GB (1e9 bytes).
u_final.nbytes / 1e9

In [0]:
# Final saved frame, transposed so that x runs horizontally.
plt.figure(figsize=(18, 6))
plt.axis('off')
plt.imshow(u_final[-1].T, cmap='RdBu');

In [0]:
# The initial state on a fixed [-1, 1] color scale, followed by 8 later
# frames. Each later frame is divided by its own max |u| so that it uses the
# full color range.
fig, axes = plt.subplots(9, 1, figsize=(14, 14))
for ax in axes:
    ax.axis('off')
axes[0].imshow(u_final[0].T, cmap='RdBu', aspect='equal', vmin=-1, vmax=1)
for i in range(8):
    frame = u_final[4*i+1]
    axes[i+1].imshow(frame.T / abs(frame).max(), cmap='RdBu', aspect='equal',
                     vmin=-1, vmax=1)

In [0]:
import matplotlib.cm
import matplotlib.colors
from PIL import Image

def make_images(data, cmap='RdBu', vmax=None):
    """Convert each 2-D frame of ``data`` into an RGBA ``PIL.Image``.

    Values are mapped through ``cmap`` on a color scale symmetric about zero:
    [-vmax, vmax] if ``vmax`` is given, otherwise each frame uses its own
    maximum absolute value.

    NOTE(review): newer Pillow releases deprecate the ``mode`` argument of
    ``Image.fromarray``. A uint8 array with 4 channels is inferred as RGBA
    anyway; confirm before dropping it.
    """
    def to_image(frame):
        limit = np.max(abs(frame)) if vmax is None else vmax
        scale = matplotlib.colors.Normalize(vmin=-limit, vmax=limit)
        mapper = matplotlib.cm.ScalarMappable(norm=scale, cmap=cmap)
        pixels = mapper.to_rgba(frame, bytes=True)
        return Image.fromarray(pixels, mode='RGBA')

    return [to_image(frame) for frame in data]

def save_movie(images, path, duration=100, loop=0, **kwargs):
    """Write the PIL ``images`` to ``path`` as a single multi-frame file.

    ``duration`` (milliseconds per frame, or a per-frame list) and ``loop``
    are forwarded to ``PIL.Image.save`` on the first image, together with any
    extra ``kwargs``.
    """
    head = images[0]
    tail = images[1:]
    options = dict(save_all=True, append_images=tail,
                   duration=duration, loop=loop)
    head.save(path, **options, **kwargs)

# Keep every 8th grid point in x and y to make the frames small, and transpose
# each frame so that x runs horizontally.
images = make_images(u_final[:, ::8, ::8].transpose(0, 2, 1))

In [0]:
# Show Movie
# Later default_bar_logger(logger) calls become default_bar_logger(None, logger),
# i.e. logger=None, presumably to silence proglog's progress bars.
# NOTE(review): re-running this cell wraps the function again and adds another
# leading None argument.
proglog.default_bar_logger = partial(proglog.default_bar_logger, None)
ImageSequenceClip([np.array(im) for im in images], fps=25).ipython_display()

In [0]:
# Save GIF: show the first and last frames for 2000 ms and the rest for 200 ms.
save_movie(
    images,
    'wave_movie.gif',
    duration=[2000] + [200] * (len(images) - 2) + [2000],
)
# The movie sometimes takes a second before showing up in the file system.
import time
time.sleep(1)

In [0]:
# Download animation.
# google.colab only imports inside Colab; elsewhere the download is skipped.
try:
    from google.colab import files
except ImportError:
    pass
else:
    files.download('wave_movie.gif')