Computing second-order sensitivities
This example demonstrates how to compute the Hessian of a differential equation solve.
This example is available as a Jupyter notebook here.
```python
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5


def vector_field(t, y, args):
    # Lotka–Volterra predator–prey dynamics; args = (α, β, γ, δ).
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = d_prey, d_predator
    return d_y
```
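For reference, the vector field above is the Lotka–Volterra predator–prey system. With `args = (α, β, γ, δ)` it reads as follows (the right-hand side does not depend on `t`):

$$
\begin{aligned}
\frac{\mathrm{d}\,\text{prey}}{\mathrm{d}t} &= \alpha\,\text{prey} - \beta\,\text{prey}\,\text{predator}, \\
\frac{\mathrm{d}\,\text{predator}}{\mathrm{d}t} &= -\gamma\,\text{predator} + \delta\,\text{prey}\,\text{predator}.
\end{aligned}
$$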
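```python
@jax.jit
@jax.hessian  # second derivatives of the returned scalar with respect to y0
def run(y0):
    term = ODETerm(vector_field)
    # scan_kind="bounded" selects a loop-over-stages implementation that supports higher-order autodiff.
    solver = Tsit5(scan_kind="bounded")
    t0 = 0
    t1 = 140
    dt0 = 0.1
    args = (0.1, 0.02, 0.4, 0.02)  # (α, β, γ, δ)
    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args)
    # By default only the solution at t1 is saved, so each component of sol.ys has shape (1,).
    ((prey,), _) = sol.ys
    return prey


y0 = (jnp.array(10.0), jnp.array(10.0))
run(y0)
```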
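Because `y0` is a tuple of two scalars, `jax.hessian` returns a nested tuple: entry `[i][j]` is the second derivative of the output with respect to `y0[i]` and `y0[j]`, so `run(y0)` gives a 2×2 nested tuple of scalars. The sketch below shows this layout on a hypothetical toy function `f` (standing in for the ODE solve so the result can be checked by hand), and also checks that `jax.hessian` agrees with `jax.jacfwd(jax.jacrev(f))`, i.e. forward mode applied over reverse mode.

```python
import jax
import jax.numpy as jnp


def f(y):
    # Hypothetical scalar function of a (prey, predator)-style tuple input.
    prey, predator = y
    return prey**2 * predator


y = (jnp.array(10.0), jnp.array(10.0))

# Nested tuple: H[i][j] = d²f / (dy[i] dy[j]).
H = jax.hessian(f)(y)
H_matrix = jnp.array([[H[0][0], H[0][1]], [H[1][0], H[1][1]]])
# Analytically: [[2 * predator, 2 * prey], [2 * prey, 0]] = [[20, 20], [20, 0]].

# jax.hessian is forward-over-reverse differentiation.
H_fwd_rev = jax.jacfwd(jax.jacrev(f))(y)
```

Since the outer pass is forward mode running over an inner reverse-mode pass, everything inside the differentiated function has to support that combination; a construct that only supports first-order reverse mode is not enough.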
Note the use of the `scan_kind` argument to `Tsit5`. By default, Diffrax internally uses constructs that are optimised specifically for first-order reverse-mode autodifferentiation. This argument is needed to switch to a different implementation that is compatible with higher-order autodiff. (In this case: for the loop over stages in the Runge–Kutta solver.)
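As a plain-JAX analogy (this is not how Diffrax implements its loops), the same loop can be written with different primitives, and the primitive decides which autodiff transformations are available. The toy ODE dy/dt = -y², the step count `N_STEPS` and the function names below are hypothetical. A `jax.lax.fori_loop` with static Python-int bounds is lowered to `lax.scan` and supports `jax.hessian`, whereas a `jax.lax.while_loop` supports only forward mode, so reverse mode (and therefore `jax.hessian`) fails on it.

```python
import jax
import jax.numpy as jnp

N_STEPS = 10  # hypothetical fixed number of steps for this toy example


def euler_step(y, dt=0.1):
    # One explicit Euler step of the toy ODE dy/dt = -y**2.
    return y + dt * (-(y**2))


def solve_fixed(y0):
    # Static Python-int bounds: JAX lowers this fori_loop to lax.scan,
    # which supports reverse mode and therefore second derivatives.
    return jax.lax.fori_loop(0, N_STEPS, lambda i, y: euler_step(y), y0)


def solve_while(y0):
    # A genuine lax.while_loop: only forward-mode differentiation is supported.
    def cond(state):
        i, _ = state
        return i < N_STEPS

    def body(state):
        i, y = state
        return i + 1, euler_step(y)

    _, y = jax.lax.while_loop(cond, body, (0, y0))
    return y


y0 = jnp.array(1.0)
second_derivative = jax.hessian(solve_fixed)(y0)  # forward-over-reverse through scan: works
first_derivative_fwd = jax.jacfwd(solve_while)(y0)  # forward mode through while_loop: works

try:
    jax.grad(solve_while)(y0)  # reverse mode through while_loop
    reverse_ok = True
except ValueError:
    # JAX refuses to transpose (reverse-differentiate) a while_loop.
    reverse_ok = False
```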
In similar fashion, if using `saveat=SaveAt(ts=...)` (or a handful of other esoteric cases) then you will need to pass `adjoint=DirectAdjoint()` to `diffeqsolve`. (In this case: for the loop that saves the output.)