{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1fe536ed",
   "metadata": {},
   "source": [
    "# Solving an ODE with a forcing term"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "598ab169-05d8-4733-a6cc-9fa91aa92198",
   "metadata": {},
   "source": [
    "This example demonstrates how to incorporate an external forcing term into the solve. This is straightforward: just evaluate the forcing term inside the vector field, like any other part of it.\n",
    "\n",
    "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/forcing.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6d6bdf63",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5\n",
    "\n",
    "\n",
    "def force(t, args):\n",
    "    \"\"\"Linear forcing term f(t) = m * t + c, with (m, c) passed via `args`.\"\"\"\n",
    "    m, c = args\n",
    "    return m * t + c\n",
    "\n",
    "\n",
    "def vector_field(t, y, args):\n",
    "    \"\"\"Right-hand side of dy/dt = -y + f(t).\"\"\"\n",
    "    return -y + force(t, args)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def solve(y0, args):\n",
    "    \"\"\"Solve the forced ODE on [0, 10], saving the solution at 1000 evenly spaced times.\"\"\"\n",
    "    term = ODETerm(vector_field)\n",
    "    solver = Tsit5()\n",
    "    t0 = 0\n",
    "    t1 = 10\n",
    "    dt0 = 0.1\n",
    "    saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))\n",
    "    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)\n",
    "    return sol\n",
    "\n",
    "\n",
    "y0 = 1.0\n",
    "args = (0.1, 0.02)  # (m, c): slope and intercept of the forcing term\n",
    "sol = solve(y0, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9654fd84-19b9-4a0b-bff6-d20f36c4f333",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(sol.ts, sol.ys)\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"y\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2043029d-e78b-410d-9a96-5f6904f2ca05",
   "metadata": {},
   "source": [
    "Now let's consider a more complicated example: the forcing term is an interpolation, and, what's more, we would like to differentiate with respect to the values being interpolated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7a3f2ea2-0067-4999-a04c-5b445e7ab749",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from diffrax import backward_hermite_coefficients, CubicInterpolation\n",
    "\n",
    "\n",
    "def vector_field2(t, y, interp):\n",
    "    \"\"\"Right-hand side of dy/dt = -y + f(t), where f is the interpolation `interp`.\"\"\"\n",
    "    return -y + interp.evaluate(t)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "@jax.grad\n",
    "def grad_solve(points):\n",
    "    \"\"\"Return y(t1) for a forcing that interpolates `points` at evenly spaced times on [0, 10].\n",
    "\n",
    "    Wrapped in `jax.grad`, so calling it returns the gradient of y(t1) with respect to `points`.\n",
    "    \"\"\"\n",
    "    t0 = 0\n",
    "    t1 = 10\n",
    "    ts = jnp.linspace(t0, t1, len(points))\n",
    "    coeffs = backward_hermite_coefficients(ts, points)\n",
    "    interp = CubicInterpolation(ts, coeffs)\n",
    "    term = ODETerm(vector_field2)\n",
    "    solver = Tsit5()\n",
    "    dt0 = 0.1\n",
    "    y0 = 1.0\n",
    "    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=interp)\n",
    "    # No `saveat` is passed, so only the final time t1 is saved: `sol.ys` has length 1.\n",
    "    (y1,) = sol.ys\n",
    "    return y1\n",
    "\n",
    "\n",
    "points = jnp.array([3.0, 0.5, -0.8, 1.8])\n",
    "grads = grad_solve(points)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f34b4824-8420-4881-b5b7-78b2118de5e0",
   "metadata": {},
   "source": [
    "In this example, we computed the interpolation once, in advance (not repeatedly on every step!), and then simply evaluated it inside the vector field."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:light"
  },
  "kernelspec": {
   "display_name": "py39",
   "language": "python",
   "name": "py39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}