{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "03617dd0-ce8d-4c5f-a9e5-edb8395c21b2",
   "metadata": {},
   "source": [
    "# Nonlinear heat PDE\n",
    "\n",
    "Diffrax can also be used to solve some PDEs.\n",
    "\n",
    "(Specifically, the scope of Diffrax is \"any numerical method which iterates over timesteps\". This means that e.g. semidiscretised evolution equations are in-scope, but e.g. finite volume methods for elliptic equations are out-of-scope.)\n",
    "\n",
    "---\n",
    "\n",
    "In this example, we solve the nonlinear heat equation\n",
    "\n",
    "$$ \\frac{\\partial y}{\\partial t}(t, x) = (1 - y(t, x)) \\Delta y(t, x) \\qquad\\text{in}\\qquad t \\in [0, 1], x \\in [-1, 1]$$\n",
    "\n",
    "subject to the initial condition\n",
    "\n",
    "$$ y(0, x) = x^2, $$\n",
    "\n",
    "and Dirichlet boundary conditions\n",
    "\n",
    "$$ y(t, -1) = 1,\\qquad y(t, 1) = 1. $$\n",
    "\n",
    "---\n",
    "\n",
    "We spatially discretise $x \\in [-1, 1]$ into points $-1 = x_0 < x_1 < \\cdots < x_{n-1} = 1$, with equal spacing $\\delta x = x_{i+1} - x_i$. The solution is then discretised into $y(t, x_i) \\approx y_i(t)$, and the Laplacian discretised into $\\Delta y(t,x_i) \\approx \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\\delta x^2}$.\n",
    "\n",
    "Doing so reduces the PDE to the system of ODEs\n",
    "\n",
    "$$ \\frac{\\mathrm{d}y_i}{\\mathrm{d}t}(t) = (1 - y_i(t)) \\frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\\delta x^2} \\qquad\\text{for}\\qquad i \\in \\{1, \\ldots, n-2\\},$$\n",
    "\n",
    "subject to the initial condition\n",
    "\n",
    "$$ y_i(0) = {x_i}^2, $$\n",
    "\n",
    "for which the Dirichlet boundary conditions become\n",
    "\n",
    "$$ \\frac{\\mathrm{d}y_0}{\\mathrm{d}t}(t) = 0,\\qquad \\frac{\\mathrm{d}y_{n-1}}{\\mathrm{d}t}(t) = 0. $$\n",
    "\n",
    "---\n",
    "\n",
    "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/nonlinear_heat_pde.ipynb).\n",
    "\n",
    "\n",
    "!!! danger \"Advanced example\"\n",
    "\n",
    "    This is an advanced example, as it involves defining a custom solver."
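    ,
    "\n",
    "\n",
    "As a quick sanity check on this semidiscretisation: the three-point stencil is exact for quadratics. So at $t = 0$, where $y_i(0) = {x_i}^2$ and $x_{i \\pm 1} = x_i \\pm \\delta x$, the discrete Laplacian can be evaluated by hand at every interior point:\n",
    "\n",
    "$$ \\frac{y_{i+1}(0) - 2y_{i}(0) + y_{i-1}(0)}{\\delta x^2} = \\frac{(x_i + \\delta x)^2 - 2{x_i}^2 + (x_i - \\delta x)^2}{\\delta x^2} = \\frac{2\\,\\delta x^2}{\\delta x^2} = 2, $$\n",
    "\n",
    "so that $\\frac{\\mathrm{d}y_i}{\\mathrm{d}t}(0) = 2(1 - {x_i}^2)$ for $i \\in \\{1, \\ldots, n-2\\}$. The solution therefore initially rises fastest at $x = 0$. Towards the boundaries it rises more and more slowly, because there it is already close to $1$."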
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0a89f429-bab4-4a0f-800c-a0c8e1c7bf9b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Callable\n",
    "\n",
    "import diffrax\n",
    "import equinox as eqx  # https://github.com/patrick-kidger/equinox\n",
    "import jax\n",
    "import jax.lax as lax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "from jaxtyping import Array, Float  # https://github.com/google/jaxtyping\n",
    "\n",
    "\n",
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "16da14af-420a-4d25-aa06-515a9baa50c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Represents the interval [x0, x_final] discretised into n equally-spaced points.\n",
    "class SpatialDiscretisation(eqx.Module):\n",
    "    x0: float = eqx.field(static=True)\n",
    "    x_final: float = eqx.field(static=True)\n",
    "    vals: Float[Array, \"n\"]\n",
    "\n",
    "    @classmethod\n",
    "    def discretise_fn(cls, x0: float, x_final: float, n: int, fn: Callable):\n",
    "        if n < 2:\n",
    "            raise ValueError(\"Must discretise [x0, x_final] into at least two points\")\n",
    "        vals = jax.vmap(fn)(jnp.linspace(x0, x_final, n))\n",
    "        return cls(x0, x_final, vals)\n",
    "\n",
    "    @property\n",
    "    def δx(self):\n",
    "        return (self.x_final - self.x0) / (len(self.vals) - 1)\n",
    "\n",
    "    def binop(self, other, fn):\n",
    "        if isinstance(other, SpatialDiscretisation):\n",
    "            if self.x0 != other.x0 or self.x_final != other.x_final:\n",
    "                raise ValueError(\"Mismatched spatial discretisations\")\n",
    "            other = other.vals\n",
    "        return SpatialDiscretisation(self.x0, self.x_final, fn(self.vals, other))\n",
    "\n",
    "    def __add__(self, other):\n",
    "        return self.binop(other, lambda x, y: x + y)\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        return self.binop(other, lambda x, y: x * y)\n",
    "\n",
    "    def __radd__(self, other):\n",
    "        return self.binop(other, lambda x, y: y + x)\n",
    "\n",
    "    def __rmul__(self, other):\n",
    "        return self.binop(other, lambda x, y: y * x)\n",
    "\n",
    "    def __sub__(self, other):\n",
    "        return self.binop(other, lambda x, y: x - y)\n",
    "\n",
    "    def __rsub__(self, other):\n",
    "        return self.binop(other, lambda x, y: y - x)\n",
    "\n",
    "\n",
    "def laplacian(y: SpatialDiscretisation) -> SpatialDiscretisation:\n",
    "    # jnp.roll wraps around at the ends; those entries are overwritten below.\n",
    "    y_prev = jnp.roll(y.vals, shift=1)  # y_prev[i] = y[i - 1]\n",
    "    y_next = jnp.roll(y.vals, shift=-1)  # y_next[i] = y[i + 1]\n",
    "    Δy = (y_next - 2 * y.vals + y_prev) / (y.δx**2)\n",
    "    # Dirichlet boundary condition\n",
    "    Δy = Δy.at[0].set(0)\n",
    "    Δy = Δy.at[-1].set(0)\n",
    "    return SpatialDiscretisation(y.x0, y.x_final, Δy)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7482e079-5ed1-4bc7-85a5-dcee3717ce7f",
   "metadata": {},
   "source": [
    "First let's try solving this semidiscretisation directly, as a system of ODEs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d304a9d0-58c7-4d29-91e6-10bc65406b73",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Problem\n",
    "def vector_field(t, y, args):\n",
    "    return (1 - y) * laplacian(y)\n",
    "\n",
    "\n",
    "term = diffrax.ODETerm(vector_field)\n",
    "ic = lambda x: x**2\n",
    "\n",
    "# Spatial discretisation\n",
    "x0 = -1\n",
    "x_final = 1\n",
    "n = 50\n",
    "y0 = SpatialDiscretisation.discretise_fn(x0, x_final, n, ic)\n",
    "\n",
    "# Temporal discretisation\n",
    "t0 = 0\n",
    "t_final = 1\n",
    "δt = 0.0001\n",
    "saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t_final, 50))\n",
    "\n",
    "# Tolerances\n",
    "rtol = 1e-10\n",
    "atol = 1e-10\n",
    "stepsize_controller = diffrax.PIDController(\n",
    "    pcoeff=0.3, icoeff=0.4, rtol=rtol, atol=atol, dtmax=0.001\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d1ad0404-5a13-4506-bdab-cdcfaf5be609",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "solver = diffrax.Tsit5()\n",
    "sol = diffrax.diffeqsolve(\n",
    "    term,\n",
    "    solver,\n",
    "    t0,\n",
    "    t_final,\n",
    "    δt,\n",
    "    y0,\n",
    "    saveat=saveat,\n",
    "    stepsize_controller=stepsize_controller,\n",
    "    max_steps=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "28185196-75f2-4465-ad59-ff45ec8c4d01",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5, 5))\n",
    "plt.imshow(\n",
    "    sol.ys.vals,\n",
    "    origin=\"lower\",\n",
    "    extent=(x0, x_final, t0, t_final),\n",
    "    aspect=(x_final - x0) / (t_final - t0),\n",
    "    cmap=\"inferno\",\n",
    ")\n",
    "plt.xlabel(\"x\")\n",
    "plt.ylabel(\"t\", rotation=0)\n",
    "plt.clim(0, 1)\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26ba8fec-3ca9-4612-b6f9-83962333d96d",
   "metadata": {},
   "source": [
    "That worked!\n",
    "\n",
    "However, for more complicated PDEs we may wish to define a custom solver. As an example, here's how to solve the same PDE using the famous [Crank–Nicolson](https://en.wikipedia.org/wiki/Crank%E2%80%93Nicolson_method) scheme.\n",
    "\n",
    "(See the page on [abstract solvers](https://docs.kidger.site/diffrax/api/solvers/abstract_solvers/) for more details about how to define a custom solver.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "059fed69-c042-4fec-bf36-60e365c98de8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class CrankNicolson(diffrax.AbstractSolver):\n",
    "    rtol: float\n",
    "    atol: float\n",
    "\n",
    "    term_structure = diffrax.ODETerm\n",
    "    interpolation_cls = diffrax.LocalLinearInterpolation\n",
    "\n",
    "    def order(self, terms):\n",
    "        return 2\n",
    "\n",
    "    def init(self, terms, t0, t1, y0, args):\n",
    "        return None\n",
    "\n",
    "    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):\n",
    "        del solver_state, made_jump\n",
    "        δt = t1 - t0\n",
    "        f0 = terms.vf(t0, y0, args)\n",
    "\n",
    "        def keep_iterating(val):\n",
    "            _, not_converged = val\n",
    "            return not_converged\n",
    "\n",
    "        def fixed_point_iteration(val):\n",
    "            y1, _ = val\n",
    "            new_y1 = y0 + 0.5 * δt * (f0 + terms.vf(t1, y1, args))\n",
    "            diff = jnp.abs((new_y1 - y1).vals)\n",
    "            max_y1 = jnp.maximum(jnp.abs(y1.vals), jnp.abs(new_y1.vals))\n",
    "            scale = self.atol + self.rtol * max_y1\n",
    "            not_converged = jnp.any(diff > scale)\n",
    "            return new_y1, not_converged\n",
    "\n",
    "        # Start from an explicit Euler step. `lax.while_loop` checks the condition\n",
    "        # before the first iteration, so the flag must start as True for the\n",
    "        # fixed point iteration to run at all.\n",
    "        euler_y1 = y0 + δt * f0\n",
    "        y1, _ = lax.while_loop(keep_iterating, fixed_point_iteration, (euler_y1, True))\n",
    "\n",
    "        # Estimate the local error as the difference from the Euler step.\n",
    "        y_error = y1 - euler_y1\n",
    "        dense_info = dict(y0=y0, y1=y1)\n",
    "\n",
    "        solver_state = None\n",
    "        result = diffrax.RESULTS.successful\n",
    "        return y1, y_error, dense_info, solver_state, result\n",
    "\n",
    "    def func(self, terms, t0, y0, args):\n",
    "        return terms.vf(t0, y0, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da4511b8-f112-4839-94f5-dfc7728da8ea",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "solver = CrankNicolson(rtol=rtol, atol=atol)\n",
    "sol = diffrax.diffeqsolve(\n",
    "    term,\n",
    "    solver,\n",
    "    t0,\n",
    "    t_final,\n",
    "    δt,\n",
    "    y0,\n",
    "    saveat=saveat,\n",
    "    stepsize_controller=stepsize_controller,\n",
    "    max_steps=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6667e3c7-5b45-4740-9caf-3e0aa4b1d7a9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5, 5))\n",
    "plt.imshow(\n",
    "    sol.ys.vals,\n",
    "    origin=\"lower\",\n",
    "    extent=(x0, x_final, t0, t_final),\n",
    "    aspect=(x_final - x0) / (t_final - t0),\n",
    "    cmap=\"inferno\",\n",
    ")\n",
    "plt.xlabel(\"x\")\n",
    "plt.ylabel(\"t\", rotation=0)\n",
    "plt.clim(0, 1)\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4b4ced9-0602-4354-a1b9-277ddf70245c",
   "metadata": {},
   "source": [
    "Some final notes.\n",
    "\n",
    "1. We wrote down the general Crank–Nicolson method, which uses a fixed point iteration to solve the implicit problem. If you know something about the structure of your problem (e.g. that it is linear), then it is often possible to use more specialised solvers, which run faster (e.g. linear solvers).\n",
    "\n",
    "2. To keep this example brief, we didn't worry about doing a von Neumann stability analysis."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}