
Computing second-order sensitivities

This example demonstrates how to compute the Hessian of a differential equation solve.

This example is available as a Jupyter notebook here.

import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5


def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = d_prey, d_predator
    return d_y


@jax.jit
@jax.hessian
def run(y0):
    term = ODETerm(vector_field)
    solver = Tsit5(scan_kind="bounded")
    t0 = 0
    t1 = 140
    dt0 = 0.1
    args = (0.1, 0.02, 0.4, 0.02)
    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args)
    ((prey,), _) = sol.ys
    return prey


y0 = (jnp.array(10.0), jnp.array(10.0))
run(y0)
((Array(3.9131193, dtype=float32, weak_type=True),
  Array(-2.374867, dtype=float32, weak_type=True)),
 (Array(-2.3748531, dtype=float32, weak_type=True),
  Array(1.688472, dtype=float32, weak_type=True)))
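Under the hood, jax.hessian composes forward-mode over reverse-mode differentiation (it is jax.jacfwd of jax.jacrev), which is why the solve must support both modes. Note also that the slight mismatch between the two off-diagonal entries above is floating-point error: the true Hessian is symmetric. A minimal pure-JAX illustration of this composition, using a made-up scalar function f (not part of the Diffrax example):

```python
import jax
import jax.numpy as jnp


def f(y):
    # An arbitrary smooth scalar function of a 2-vector.
    return jnp.sin(y[0]) * y[1] ** 2


y = jnp.array([0.5, 2.0])

# jax.hessian is forward-over-reverse: jacfwd composed with jacrev.
h1 = jax.hessian(f)(y)
h2 = jax.jacfwd(jax.jacrev(f))(y)

# The two agree, and the Hessian is symmetric up to floating-point error.
print(h1)
print(h2)
```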

Note the use of the scan_kind argument to Tsit5. By default, Diffrax internally uses constructs that are optimised specifically for first-order reverse-mode autodifferentiation. This argument switches to a different implementation that is also compatible with higher-order autodiff. (In this case: for the loop over stages in the Runge–Kutta solver.)

Similarly, if you are using saveat=SaveAt(ts=...) (or one of a handful of other esoteric cases), then you will also need to pass adjoint=DirectAdjoint(). (In this case: for the loop over saved outputs.)