{ "cells": [ { "cell_type": "markdown", "id": "1fe536ed", "metadata": {}, "source": [ "# Computing second-order sensitivities" ] }, { "cell_type": "markdown", "id": "598ab169-05d8-4733-a6cc-9fa91aa92198", "metadata": {}, "source": [ "This example demonstrates how to compute the Hessian of a differential equation solve.\n", "\n", "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/hessian.ipynb)." ] }, { "cell_type": "code", "execution_count": 1, "id": "6d6bdf63", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "((Array(3.9131193, dtype=float32, weak_type=True),\n", " Array(-2.374867, dtype=float32, weak_type=True)),\n", " (Array(-2.3748531, dtype=float32, weak_type=True),\n", " Array(1.688472, dtype=float32, weak_type=True)))" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from diffrax import diffeqsolve, ODETerm, Tsit5\n", "\n", "\n", "def vector_field(t, y, args):\n", " prey, predator = y\n", " α, β, γ, δ = args\n", " d_prey = α * prey - β * prey * predator\n", " d_predator = -γ * predator + δ * prey * predator\n", " d_y = d_prey, d_predator\n", " return d_y\n", "\n", "\n", "@jax.jit\n", "@jax.hessian\n", "def run(y0):\n", " term = ODETerm(vector_field)\n", " solver = Tsit5(scan_kind=\"bounded\")\n", " t0 = 0\n", " t1 = 140\n", " dt0 = 0.1\n", " args = (0.1, 0.02, 0.4, 0.02)\n", " sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args)\n", " ((prey,), _) = sol.ys\n", " return prey\n", "\n", "\n", "y0 = (jnp.array(10.0), jnp.array(10.0))\n", "run(y0)" ] }, { "cell_type": "markdown", "id": "a3ec6532-5b0a-4e4c-af33-bef58c0a7319", "metadata": {}, "source": [ "Note the use of the `scan_kind` argument to `Tsit5`. By default, Diffrax internally uses constructs that are optimised specifically for first-order reverse-mode autodifferentiation. 
This argument switches to a different implementation that is compatible with higher-order autodiff. (In this case: for the loop over stages in the Runge-Kutta solver.)\n", "\n", "Similarly, if you use `saveat=SaveAt(ts=...)` (or one of a handful of other esoteric cases) then you will need to pass `adjoint=DirectAdjoint()`. (In this case: for the loop that saves output.)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,py:light" }, "kernelspec": { "display_name": "py39", "language": "python", "name": "py39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }