This example demonstrates how to use Diffrax to solve an ODE until it reaches a steady state. The key feature will be the use of event handling to detect that the steady state has been reached.

In addition, for this example we need to backpropagate through the procedure of finding a steady state. We can do this efficiently using the implicit function theorem.

This example is available as a Jupyter notebook in the Diffrax documentation repository.

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax

class ExponentialDecayToSteadyState(eqx.Module):
    """ODE vector field dy/dt = answer - y.

    Solutions decay exponentially towards the learnable value `answer`,
    which is therefore the steady state of the ODE. `answer` is the
    parameter optimized in the training loop below.
    """

    # Learnable steady state; initialized from the constructor argument
    # (see `ExponentialDecayToSteadyState(jnp.array(0.0))` below).
    answer: float

    def __call__(self, t, y, args):
        # Standard Diffrax vector-field signature (t, y, args).
        return self.answer - y

def loss(model, target_steady_state):
    """Solve the ODE defined by `model` until steady state; return the
    squared error between the reached steady state and the target.

    Gradients flow through the steady state via the implicit function
    theorem (`diffrax.ImplicitAdjoint`), not by unrolling solver steps.
    """
    term = diffrax.ODETerm(model)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf  # integrate forever...
    dt0 = None  # ...with an adaptively chosen initial step size...
    y0 = 1.0
    max_steps = None  # ...and no cap on the number of steps.
    controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)
    # Terminate the solve once the solution stops changing, i.e. once the
    # steady state has been reached.
    event = diffrax.SteadyStateEvent()
    # Backpropagate through the steady state using the implicit function
    # theorem instead of differentiating through the solve itself.
    adjoint = diffrax.ImplicitAdjoint()
    # This combination of event, t1, max_steps, adjoint is particularly
    # natural: we keep integration forever until we hit the event, with
    # no maximum time or number of steps. Backpropagation happens via
    # the implicit function theorem.
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        max_steps=max_steps,
        stepsize_controller=controller,
        discrete_terminating_event=event,
        adjoint=adjoint,
    )
    # Only the final (steady-state) value is saved by default.
    (y1,) = sol.ys
    return (y1 - target_steady_state) ** 2

# Build the model, the optimization target, and the optimizer state.
model = ExponentialDecayToSteadyState(
    jnp.array(0.0)
)  # initial steady state guess is 0.
target_steady_state = jnp.array(0.76)  # target steady state is 0.76
optim = optax.sgd(1e-2, momentum=0.7, nesterov=True)
opt_state = optim.init(model)

@eqx.filter_jit
def make_step(model, opt_state, target_steady_state):
    """Perform one gradient-descent step on `loss`.

    Returns the updated (model, opt_state) pair. JIT-compiled with
    `eqx.filter_jit` so only array leaves are traced.
    """
    # filter_grad differentiates w.r.t. the array leaves of `model` only.
    grads = eqx.filter_grad(loss)(model, target_steady_state)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state

# Train for a fixed number of steps, reporting the current steady-state
# estimate (the sample output below was produced by this print).
for step in range(100):
    model, opt_state = make_step(model, opt_state, target_steady_state)
    print(f"Step: {step} Steady State: {model.answer}")

Example output: `Step: 0 Steady State: 0.025839969515800476`