Skip to content

Steady states

This example demonstrates how to use Diffrax to solve an ODE until it reaches a steady state. The key feature is the use of event handling to detect that the steady state has been reached.

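In addition, for this example we need to backpropagate through the procedure of finding a steady state. We can do this efficiently using the implicit function theorem: at a steady state $y^*$ the vector field vanishes, $f(y^*, \theta) = 0$, so $\frac{\mathrm{d}y^*}{\mathrm{d}\theta} = -\left(\frac{\partial f}{\partial y}\right)^{-1} \frac{\partial f}{\partial \theta}$ can be evaluated directly at $y^*$ without backpropagating through every step of the (potentially unbounded) solve.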

This example is available as a Jupyter notebook here.

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
class ExponentialDecayToSteadyState(eqx.Module):
    """Vector field of the linear ODE ``dy/dt = steady_state - y``.

    The right-hand side vanishes exactly when ``y == steady_state``. For any
    initial condition the solution ``y(t) = s + (y0 - s) * exp(-t)`` decays
    exponentially towards ``s = steady_state``. Because ``steady_state`` is a
    module field, it is the parameter that gets learned.
    """

    # Learnable equilibrium value. The example stores a JAX scalar array here.
    steady_state: float

    def __call__(self, t, y, args):
        # Autonomous system with no extra arguments: time and args are unused.
        del t, args
        rate = self.steady_state - y
        return rate
def loss(model, target_steady_state):
    """Squared error between the ODE's reached steady state and a target.

    Integrates ``model`` as a vector field from ``y0 = 1.0`` at ``t0 = 0``
    with Tsit5 and an adaptive PID step-size controller. Integration stops
    when ``diffrax.SteadyStateEvent`` fires.

    Args:
        model: Callable ``(t, y, args) -> dy/dt`` wrapped in an ``ODETerm``.
        target_steady_state: Value the reached state is compared against.

    Returns:
        ``(y_final - target_steady_state) ** 2``.
    """
    # This combination of event, t1, max_steps and adjoint is particularly
    # natural: we keep integrating forever (t1=inf, no step cap) until the
    # steady-state event fires. Backpropagation then happens via the
    # implicit function theorem rather than through every solver step.
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(model),
        diffrax.Tsit5(),
        t0=0,
        t1=jnp.inf,
        dt0=None,
        y0=1.0,
        max_steps=None,
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        discrete_terminating_event=diffrax.SteadyStateEvent(),
        adjoint=diffrax.ImplicitAdjoint(),
    )
    # Exactly one saved state is expected (the final one); the unpacking
    # fails loudly if that ever stops being true.
    (final_state,) = solution.ys
    return (final_state - target_steady_state) ** 2
# Start the learnable steady state from an initial guess of 0.
model = ExponentialDecayToSteadyState(steady_state=jnp.array(0.0))
# The steady state the optimiser should recover: 0.76.
target_steady_state = jnp.array(0.76)
# SGD with Nesterov momentum; its state is initialised from the model's leaves.
optim = optax.sgd(learning_rate=1e-2, momentum=0.7, nesterov=True)
opt_state = optim.init(model)


@eqx.filter_jit
def make_step(model, opt_state, target_steady_state):
    """Perform one JIT-compiled optimiser update on ``model``.

    Differentiates the module-level ``loss`` with respect to ``model`` and
    applies the module-level ``optim`` update rule.

    Returns:
        A ``(new_model, new_opt_state)`` pair.
    """
    loss_gradient = eqx.filter_grad(loss)
    gradients = loss_gradient(model, target_steady_state)
    step_updates, new_opt_state = optim.update(gradients, opt_state)
    new_model = eqx.apply_updates(model, step_updates)
    return new_model, new_opt_state


# Run 100 optimisation steps, printing the learned steady state after each
# one so its convergence towards the target can be followed.
for step in range(100):
    model, opt_state = make_step(model, opt_state, target_steady_state)
    print(f"Step: {step} Steady State: {model.steady_state}")
# Print the target for comparison with the final learned value.
print(f"Target: {target_steady_state}")
Step: 0 Steady State: 0.025839969515800476
Step: 1 Steady State: 0.058249037712812424
Step: 2 Steady State: 0.09451574087142944
Step: 3 Steady State: 0.13270404934883118
Step: 4 Steady State: 0.17144456505775452
Step: 5 Steady State: 0.2097906768321991
Step: 6 Steady State: 0.24709917604923248
Step: 7 Steady State: 0.28294336795806885
Step: 8 Steady State: 0.3170691728591919
Step: 9 Steady State: 0.34933507442474365
Step: 10 Steady State: 0.37968066334724426
Step: 11 Steady State: 0.4081019163131714
Step: 12 Steady State: 0.43463483452796936
Step: 13 Steady State: 0.45934173464775085
Step: 14 Steady State: 0.4823019802570343
Step: 15 Steady State: 0.5035936236381531
Step: 16 Steady State: 0.5233209133148193
Step: 17 Steady State: 0.5415788888931274
Step: 18 Steady State: 0.5584676265716553
Step: 19 Steady State: 0.5740787982940674
Step: 20 Steady State: 0.5885017514228821
Step: 21 Steady State: 0.6018210053443909
Step: 22 Steady State: 0.6141175627708435
Step: 23 Steady State: 0.6254667043685913
Step: 24 Steady State: 0.6359376907348633
Step: 25 Steady State: 0.6455990076065063
Step: 26 Steady State: 0.6545112729072571
Step: 27 Steady State: 0.6627309322357178
Step: 28 Steady State: 0.6703115701675415
Step: 29 Steady State: 0.6773026585578918
Step: 30 Steady State: 0.6837494373321533
Step: 31 Steady State: 0.6896938681602478
Step: 32 Steady State: 0.6951748728752136
Step: 33 Steady State: 0.7002284526824951
Step: 34 Steady State: 0.7048872113227844
Step: 35 Steady State: 0.7091819047927856
Step: 36 Steady State: 0.7131412029266357
Step: 37 Steady State: 0.7167739868164062
Step: 38 Steady State: 0.7201183438301086
Step: 39 Steady State: 0.7231980562210083
Step: 40 Steady State: 0.7260348796844482
Step: 41 Steady State: 0.7286462187767029
Step: 42 Steady State: 0.7310511469841003
Step: 43 Steady State: 0.733269989490509
Step: 44 Steady State: 0.7353137731552124
Step: 45 Steady State: 0.7371994853019714
Step: 46 Steady State: 0.7389383912086487
Step: 47 Steady State: 0.740541934967041
Step: 48 Steady State: 0.7420334219932556
Step: 49 Steady State: 0.7434003353118896
Step: 50 Steady State: 0.7446598410606384
Step: 51 Steady State: 0.7458205819129944
Step: 52 Steady State: 0.7468900680541992
Step: 53 Steady State: 0.7478761672973633
Step: 54 Steady State: 0.7487852573394775
Step: 55 Steady State: 0.7496234178543091
Step: 56 Steady State: 0.750394344329834
Step: 57 Steady State: 0.7511063814163208
Step: 58 Steady State: 0.751763105392456
Step: 59 Steady State: 0.7523672580718994
Step: 60 Steady State: 0.7529228329658508
Step: 61 Steady State: 0.753433346748352
Step: 62 Steady State: 0.7539049983024597
Step: 63 Steady State: 0.7543382048606873
Step: 64 Steady State: 0.7547407746315002
Step: 65 Steady State: 0.7551127672195435
Step: 66 Steady State: 0.7554563879966736
Step: 67 Steady State: 0.7557693123817444
Step: 68 Steady State: 0.7560611367225647
Step: 69 Steady State: 0.7563308477401733
Step: 70 Steady State: 0.7565800547599792
Step: 71 Steady State: 0.756810188293457
Step: 72 Steady State: 0.7570226788520813
Step: 73 Steady State: 0.7572163343429565
Step: 74 Steady State: 0.7573966979980469
Step: 75 Steady State: 0.7575633525848389
Step: 76 Steady State: 0.7577127814292908
Step: 77 Steady State: 0.7578537464141846
Step: 78 Steady State: 0.7579842805862427
Step: 79 Steady State: 0.7581048607826233
Step: 80 Steady State: 0.7582123279571533
Step: 81 Steady State: 0.7583134770393372
Step: 82 Steady State: 0.7584078907966614
Step: 83 Steady State: 0.7584953904151917
Step: 84 Steady State: 0.758575975894928
Step: 85 Steady State: 0.7586501836776733
Step: 86 Steady State: 0.7587193250656128
Step: 87 Steady State: 0.7587832808494568
Step: 88 Steady State: 0.7588424682617188
Step: 89 Steady State: 0.7588958144187927
Step: 90 Steady State: 0.7589460015296936
Step: 91 Steady State: 0.7589924931526184
Step: 92 Steady State: 0.7590354681015015
Step: 93 Steady State: 0.7590752243995667
Step: 94 Steady State: 0.7591111063957214
Step: 95 Steady State: 0.7591448426246643
Step: 96 Steady State: 0.7591760754585266
Step: 97 Steady State: 0.7592049241065979
Step: 98 Steady State: 0.7592315673828125
Step: 99 Steady State: 0.7592562437057495
Target: 0.7599999904632568