Skip to content

Steady states

This example demonstrates how to use Diffrax to solve an ODE until it reaches a steady state. The key feature is the use of event handling to detect when the steady state has been reached.

In addition, for this example we need to backpropagate through the procedure of finding a steady state. We can do this efficiently using the implicit function theorem.

This example is available as a Jupyter notebook here.

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
class ExponentialDecayToSteadyState(eqx.Module):
    """Vector field dy/dt = steady_state - y.

    Every solution decays exponentially towards ``steady_state``, which is the
    single trainable parameter of this model.
    """

    # Fixed point of the ODE. The script below passes a 0-d jnp array here.
    steady_state: float

    def __call__(self, t, y, args):
        # The field does not depend on time `t` or on `args`.
        fixed_point = self.steady_state
        return fixed_point - y
def loss(model, target_steady_state):
    """Return the squared error between the model's reached steady state and a target.

    The model's ODE is integrated from y(0) = 1.0 with Tsit5 and an adaptive
    PID step-size controller (rtol=1e-3, atol=1e-6). The solve stops only when
    Diffrax's steady-state event fires.
    """
    # t1 = inf together with max_steps = None means "integrate until the event
    # triggers", with no time or step budget. ImplicitAdjoint differentiates
    # the result via the implicit function theorem, so gradients do not have to
    # be propagated through every solver step.
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(model),
        diffrax.Tsit5(),
        0,  # t0
        jnp.inf,  # t1
        None,  # dt0: let the step-size controller pick the first step
        1.0,  # y0
        max_steps=None,
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        event=diffrax.Event(diffrax.steady_state_event()),
        adjoint=diffrax.ImplicitAdjoint(),
    )
    # Exactly one saved state is expected: the state where the solve stopped.
    [reached_state] = solution.ys
    return (reached_state - target_steady_state) ** 2
# Start from a steady-state guess of 0 and train it towards 0.76.
model = ExponentialDecayToSteadyState(steady_state=jnp.array(0.0))
target_steady_state = jnp.array(0.76)
# Plain SGD with Nesterov momentum. Its state is initialised from the model.
optim = optax.sgd(learning_rate=1e-2, momentum=0.7, nesterov=True)
opt_state = optim.init(model)


@eqx.filter_jit
def make_step(model, opt_state, target_steady_state):
    """Apply one optimiser update to `model` using the gradient of `loss`.

    Returns the updated model together with the new optimiser state. The
    module-level `optim` transformation performs the update.
    """
    loss_gradient = eqx.filter_grad(loss)
    gradients = loss_gradient(model, target_steady_state)
    updates, new_opt_state = optim.update(gradients, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state


# Run 100 optimisation steps, reporting the learned steady state after each.
for step in range(100):
    model, opt_state = make_step(model, opt_state, target_steady_state)
    learned_steady_state = model.steady_state
    print(f"Step: {step} Steady State: {learned_steady_state}")
print(f"Target: {target_steady_state}")
Step: 0 Steady State: 0.025839969515800476
Step: 1 Steady State: 0.05824900045990944
Step: 2 Steady State: 0.09451568126678467
Step: 3 Steady State: 0.1327039748430252
Step: 4 Steady State: 0.1714443564414978
Step: 5 Steady State: 0.20979028940200806
Step: 6 Steady State: 0.24709881842136383
Step: 7 Steady State: 0.28294941782951355
Step: 8 Steady State: 0.31707584857940674
Step: 9 Steady State: 0.34934186935424805
Step: 10 Steady State: 0.37968698143959045
Step: 11 Steady State: 0.4081074893474579
Step: 12 Steady State: 0.43463948369026184
Step: 13 Steady State: 0.45934492349624634
Step: 14 Steady State: 0.48230400681495667
Step: 15 Steady State: 0.5036059021949768
Step: 16 Steady State: 0.5233321189880371
Step: 17 Steady State: 0.5415896773338318
Step: 18 Steady State: 0.5584752559661865
Step: 19 Steady State: 0.5740804076194763
Step: 20 Steady State: 0.5884985327720642
Step: 21 Steady State: 0.6018134951591492
Step: 22 Steady State: 0.6141058206558228
Step: 23 Steady State: 0.6254505515098572
Step: 24 Steady State: 0.6359192728996277
Step: 25 Steady State: 0.6455777287483215
Step: 26 Steady State: 0.6544871926307678
Step: 27 Steady State: 0.6627050638198853
Step: 28 Steady State: 0.6702842116355896
Step: 29 Steady State: 0.6772737503051758
Step: 30 Steady State: 0.6837191581726074
Step: 31 Steady State: 0.6896624565124512
Step: 32 Steady State: 0.6951420903205872
Step: 33 Steady State: 0.7001940608024597
Step: 34 Steady State: 0.7048525214195251
Step: 35 Steady State: 0.709147572517395
Step: 36 Steady State: 0.7131075263023376
Step: 37 Steady State: 0.7167584300041199
Step: 38 Steady State: 0.720124363899231
Step: 39 Steady State: 0.7232275605201721
Step: 40 Steady State: 0.7260884642601013
Step: 41 Steady State: 0.7287259697914124
Step: 42 Steady State: 0.7311574816703796
Step: 43 Steady State: 0.7333983778953552
Step: 44 Steady State: 0.7354647517204285
Step: 45 Steady State: 0.7373697757720947
Step: 46 Steady State: 0.7391260266304016
Step: 47 Steady State: 0.7407451272010803
Step: 48 Steady State: 0.7422377467155457
Step: 49 Steady State: 0.7436137795448303
Step: 50 Steady State: 0.7448822855949402
Step: 51 Steady State: 0.7460517287254333
Step: 52 Steady State: 0.7471297979354858
Step: 53 Steady State: 0.7481234669685364
Step: 54 Steady State: 0.7490396499633789
Step: 55 Steady State: 0.7498842477798462
Step: 56 Steady State: 0.7506628632545471
Step: 57 Steady State: 0.7513806223869324
Step: 58 Steady State: 0.7520219683647156
Step: 59 Steady State: 0.7526065707206726
Step: 60 Steady State: 0.7531405687332153
Step: 61 Steady State: 0.7536292672157288
Step: 62 Steady State: 0.754077136516571
Step: 63 Steady State: 0.7544881105422974
Step: 64 Steady State: 0.7548655867576599
Step: 65 Steady State: 0.7552322149276733
Step: 66 Steady State: 0.7555564045906067
Step: 67 Steady State: 0.7558530569076538
Step: 68 Steady State: 0.7561249732971191
Step: 69 Steady State: 0.7563938498497009
Step: 70 Steady State: 0.7566279768943787
Step: 71 Steady State: 0.7568415403366089
Step: 72 Steady State: 0.7570368051528931
Step: 73 Steady State: 0.7572155594825745
Step: 74 Steady State: 0.7573794722557068
Step: 75 Steady State: 0.7575299143791199
Step: 76 Steady State: 0.757668137550354
Step: 77 Steady State: 0.7577952742576599
Step: 78 Steady State: 0.7579122185707092
Step: 79 Steady State: 0.7580198645591736
Step: 80 Steady State: 0.7581189870834351
Step: 81 Steady State: 0.758210301399231
Step: 82 Steady State: 0.7583132982254028
Step: 83 Steady State: 0.7583956122398376
Step: 84 Steady State: 0.7584698796272278
Step: 85 Steady State: 0.7585371136665344
Step: 86 Steady State: 0.7585982084274292
Step: 87 Steady State: 0.7586538791656494
Step: 88 Steady State: 0.7587047219276428
Step: 89 Steady State: 0.7587512731552124
Step: 90 Steady State: 0.7587938904762268
Step: 91 Steady State: 0.7588329911231995
Step: 92 Steady State: 0.758868932723999
Step: 93 Steady State: 0.7589019536972046
Step: 94 Steady State: 0.7589322924613953
Step: 95 Steady State: 0.7589602470397949
Step: 96 Steady State: 0.7589859366416931
Step: 97 Steady State: 0.7590096592903137
Step: 98 Steady State: 0.7590314745903015
Step: 99 Steady State: 0.7590516209602356
Target: 0.7599999904632568