Diffrax: JAX-based numerical differential equation solvers

Patrick Kidger

github.com/patrick-kidger/diffrax
@PatrickKidger
--split--

!!! abstract "Summary of features"

    - Ordinary/stochastic/controlled diffeq solvers;
    - High-order, implicit, symplectic solvers;
    - Using a PyTree as the state;
    - Dense solutions;
    - Multiple adjoint methods for backpropagation.

## Easy-to-use syntax

Let's solve the ODE \\(\frac{\mathrm{d}y}{\mathrm{d}t} = -y\\):

```python
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Dopri5

# Vector field f(t, y, args) = -y; `args` carries any extra parameters (unused here).
def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()  # Dormand-Prince 5(4) explicit Runge-Kutta method
# Integrate from t0=0 to t1=1 with step size dt0=0.1, for two initial values at once.
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=jnp.array([2., 3.]))
# By default only the state at t1 is saved, in solution.ys.
```

## New idea: unified solving

At a technical level, the internal structure of the library does some pretty cool new stuff! Most important is the idea of solving ODEs and SDEs in a single unified way; this produces a small, tightly written library.

Specifically: ordinary differential equations

\\[\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))\\]

and stochastic differential equations

\\[\mathrm{d}y(t) = f_1(t, y(t))\,\mathrm{d}t + f_2(t, y(t))\,\mathrm{d}w(t)\\]

are solved in a unified way by lowering them to *controlled* diffeqs:

\\[\mathrm{d}y(t) = f(t, y(t))\,\mathrm{d}x(t),\\]

where e.g. \\(x(t) = t\\) for an ODE, and \\(x(t) = [t, w(t)]\\) with \\(f = [f_1, f_2]\\) for an SDE.

--split--

## Versus other libraries?
(torchdiffeq, Julia, etc.)

Diffrax is better for advanced use cases:

- Adding your own custom ops;
- Solving ODEs/SDEs simultaneously;
- Solving SDEs with controls, or multiple noise terms;
- etc.

Diffrax is also *fast*:

- 1.3 to 20 times faster than torchdiffeq;
- Similar speed to DifferentialEquations.jl (precise benchmarks WIP).

## Extending Diffrax

Diffrax is designed to be highly extensible.

- There is a sophisticated collection of abstract base classes (`AbstractSolver` etc.) through which you can easily add custom ops.
- If you're writing e.g. a differentiable simulator and want to step through the solve yourself, then this is also possible.

!!! tip "Next steps"

    **Installation:** `pip install diffrax`

    **Documentation:** [https://docs.kidger.site/diffrax](https://docs.kidger.site/diffrax)

    **Reference:** P. Kidger, *On Neural Differential Equations*, Doctoral Thesis, University of Oxford, 2021.

    ```bibtex
    @phdthesis{kidger2021on,
        title={{O}n {N}eural {D}ifferential {E}quations},
        author={Patrick Kidger},
        year={2021},
        school={University of Oxford},
    }
    ```