How to choose a solver¤
Ordinary differential equations¤
The full list of ODE solvers is available on the ODE solvers page.
Info
ODE problems are informally divided into "stiff" and "non-stiff" problems. "Stiffness" generally refers to how difficult an equation is to solve numerically. Non-stiff problems are quite common, and usually solved using straightforward techniques like explicit Runge--Kutta methods. Stiff problems usually require more computationally expensive techniques, like implicit Runge--Kutta methods.
Non-stiff problems¤
For non-stiff problems then diffrax.Tsit5
is a good general-purpose solver.
Note
For a long time the recommend default solver for many problems was diffrax.Dopri5
. This is the default solver used in torchdiffeq
, and is the solver used in MATLAB's ode45
. However Tsit5
is now reckoned on being slightly more efficient overall. (Try both if you wish.)
If you need accurate solutions at tight tolerances then try diffrax.Dopri8
.
If you are solving a neural differential equation, and training via discretise-then-optimise (corresponding to diffeqsolve(..., adjoint=RecursiveCheckpointAdjoint())
, which is the default), then accurate solutions are often not needed and a low-order solver will be most efficient. For example something like diffrax.Heun
.
Stiff problems¤
For stiff problems then try the diffrax.Kvaerno3
, diffrax.Kvaerno4
, diffrax.Kvaerno5
family of solvers. In addition you should almost always use an adaptive step size controller such as diffrax.PIDController
.
See also the Stiff ODE example.
Danger
If solving a differential equation (stiff or not) to relatively high tolerances (typically \(10^{-8}\) or lower) then you should make sure to use 64-bit precision, instead of JAX's default 32-bit precision. Not doing so can introduce a variety of interesting errors. For example the following are all symptoms of having failed to do this:
NaN
gradients;- Taking many more solver steps than necessary (e.g. 8 steps -> 800 steps);
- Wrapping with
jax.value_and_grad
orjax.grad
actually changing the result of the primal (forward) computation.
Split problems¤
For "split stiffness" problems, with one term that is stiff and another term that is non-stiff, then IMEX methods are appropriate: diffrax.KenCarp4
is recommended. In addition you should almost always use an adaptive step size controller such as diffrax.PIDController
.
Stochastic differential equations¤
SDE solvers are relatively specialised depending on the type of problem. Each solver will converge to either the Itô solution or the Stratonovich solution. In addition some solvers require "commutative noise".
Commutative noise
Consider the SDE
\(\mathrm{d}y(t) = μ(t, y(t))\mathrm{d}t + σ(t, y(t))\mathrm{d}w(t)\)
then the diffusion matrix \(σ\) is said to satisfy the commutativity condition if
\(\sum_{i=1}^d σ_{i, j} \frac{\partial σ_{k, l}}{\partial y_i} = \sum_{i=1}^d σ_{i, l} \frac{\partial σ_{k, j}}{\partial y_i}\)
Some common special cases in which this condition is satisfied are:
- the diffusion is additive (\(σ\) is independent of \(y\));
- the noise is scalar (\(w\) is scalar-valued);
- the diffusion is diagonal (\(σ\) is a diagonal matrix and such that the i-th diagonal entry depends only on \(y_i\); not to be confused with the simpler but insufficient condition that \(σ\) is only a diagonal matrix)
Itô¤
For Itô SDEs:
- If the noise is commutative then
diffrax.ItoMilstein
is a typical choice; - If the noise is noncommutative then
diffrax.Euler
is a typical choice.
Stratonovich¤
For Stratonovich SDEs:
- If cheap low-accuracy solves are desired then
diffrax.EulerHeun
is a typical choice. - Otherwise, and if the noise is commutative, then
diffrax.SlowRK
has the best order of convergence, but is expensive per step.diffrax.StratonovichMilstein
is a good cheap alternative. - If the noise is noncommutative,
diffrax.GeneralShARK
is the most efficient choice, whilediffrax.Heun
is a good cheap alternative. - If the noise is noncommutative and an embedded method for adaptive step size control is desired, then
diffrax.SPaRK
is the recommended choice.
Additive noise¤
Consider the SDE
\(\mathrm{d}y(t) = μ(t, y(t))\mathrm{d}t + σ(t, y(t))\mathrm{d}w(t)\)
Then the diffusion matrix \(σ\) is said to be additive if \(σ(t, y) = σ(t)\). That is to say if the diffusion is independent of \(y\).
In this case the Itô solution and the Stratonovich solution coincide, and mathematically speaking the choice of Itô vs Stratonovich is unimportant. Special solvers for additive-noise SDEs tend to do particularly well as compared to the general Itô or Stratonovich solvers discussed above.
- The cheapest (but least accurate) solver is
diffrax.SEA
. - Otherwise
diffrax.ShARK
ordiffrax.SRA1
are good choices.
Underdamped Langevin Diffusion¤
The Underdamped Langevin Diffusion is a special case of an SDE with additive noise. For details on the form of this SDE and appropriate solvers, please refer to the section on Underdamped Langevin solvers.
Controlled differential equations¤
As an ODE¤
If the control is differentiable (e.g. an interpolation of discrete data) and isn't somehow "rough" (i.e. doesn't wiggle up and down at a very fine timescale that is smaller than you want to make numerical steps) then probably the best way to solve the CDE is to reduce it to an ODE:
vector_field = ...
control = ...
term = ControlTerm(vector_field, control)
term = term.to_ode()
Then use any of the ODE solvers as discussed above.
Directly discretising the control¤
The other option is to directly discretise the control. Given some control \(x \colon [0, T] \to \mathbb{R}^d\) then this means solving \(\mathrm{d}y(t) = f(y(t)) \mathrm{d} x(t)\) by treating \(x\) a bit like time, and replacing the \(\Delta t\) in most numerical solvers with some \(\Delta x(t)\) instead.
(This is actually the principle on which many SDE solvers work.)
It is an open question what are the best solvers to use when taking this approach, but low-order solvers are typical: for example diffrax.Euler
or diffrax.Heun
.