Steady states
This example demonstrates how to use Diffrax to solve an ODE until it reaches a steady state. The key feature is the use of event handling to detect when the steady state has been reached.
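In addition, this example needs to backpropagate through the procedure of finding a steady state. We can do this efficiently using the implicit function theorem (via `diffrax.ImplicitAdjoint`), which differentiates the steady-state condition directly instead of backpropagating through every step the solver took to get there.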
This example is available as a Jupyter notebook here.
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp
import optax # https://github.com/deepmind/optax
class ExponentialDecayToSteadyState(eqx.Module):
    """Vector field ``dy/dt = steady_state - y``.

    Solutions decay exponentially towards ``steady_state``, which is the ODE's
    unique equilibrium and the quantity this example learns.
    """

    # Annotated as an array rather than ``float``: ``eqx.filter_grad`` only
    # differentiates floating-point arrays, so this field must be supplied as
    # e.g. ``jnp.array(0.0)`` for it to receive gradients during training.
    steady_state: jnp.ndarray

    def __call__(self, t, y, args):
        """Return ``dy/dt = steady_state - y``; ``t`` and ``args`` are unused."""
        return self.steady_state - y
def loss(model, target_steady_state):
    """Squared error between the steady state reached by ``model`` and a target.

    The ODE ``dy/dt = model(t, y, args)`` is integrated with Tsit5 from
    ``t0 = 0`` and ``y0 = 1.0`` towards ``t1 = inf``, with no step limit.
    The solve ends when ``diffrax.steady_state_event`` fires.

    Gradients are computed with ``diffrax.ImplicitAdjoint``, i.e. via the
    implicit function theorem.

    Args:
        model: vector field called as ``model(t, y, args)``.
        target_steady_state: value the terminal state is compared against.

    Returns:
        ``(y_final - target_steady_state) ** 2``.
    """
    # Event + infinite t1 + unlimited max_steps + implicit adjoint is a natural
    # pairing: we integrate until the event triggers, with no maximum time or
    # number of steps, and backpropagate via the implicit function theorem.
    steady_state_reached = diffrax.Event(diffrax.steady_state_event())
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(model),
        diffrax.Tsit5(),
        0,  # t0
        jnp.inf,  # t1: no final time; the event terminates the solve
        None,  # dt0: the adaptive controller chooses the first step
        1.0,  # y0
        max_steps=None,
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        event=steady_state_reached,
        adjoint=diffrax.ImplicitAdjoint(),
    )
    # The unpacking requires ``ys`` to hold exactly one saved state
    # (diffrax's default is to save only at the final time).
    [y_final] = solution.ys
    error = y_final - target_steady_state
    return error**2
# Target steady state is 0.76.
target_steady_state = jnp.array(0.76)

# Initial steady-state guess is 0. It is wrapped in a JAX array so that
# ``eqx.filter_grad`` treats it as a differentiable leaf.
model = ExponentialDecayToSteadyState(steady_state=jnp.array(0.0))

# SGD with Nesterov momentum; the optimiser state is initialised from the
# model pytree.
optim = optax.sgd(learning_rate=1e-2, momentum=0.7, nesterov=True)
opt_state = optim.init(model)
@eqx.filter_jit
def make_step(model, opt_state, target_steady_state):
    """Perform one JIT-compiled optimisation step.

    Differentiates ``loss`` with respect to the model's floating-point arrays,
    transforms the gradients with the module-level ``optim``, and applies the
    resulting updates.

    Returns:
        ``(new_model, new_opt_state)``.
    """
    loss_grad = eqx.filter_grad(loss)
    grads = loss_grad(model, target_steady_state)
    updates, new_opt_state = optim.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state
# Run 100 optimisation steps, printing the current learnt steady state after
# each one, then print the target for comparison.
for step in range(100):
    model, opt_state = make_step(model, opt_state, target_steady_state)
    print(f"Step: {step} Steady State: {model.steady_state}")

print(f"Target: {target_steady_state}")