Interactively step through a solve

Sometimes you might want to perform a differential equation solve just one step at a time (or a few steps at a time), and perhaps do some other computations in between. A common example is solving a differential equation in real time while continually producing some output.

One option is to repeatedly call diffrax.diffeqsolve. However, if that seems inelegant or inefficient to you, then it is possible to use the solvers (and step size controllers, etc.) directly yourself.

In the following example, we solve an ODE using diffrax.Tsit5, and print out the result as we go.

Note

See the Abstract solvers page for a reference on the solver methods (init, step) used here.

import jax.numpy as jnp
from diffrax import ODETerm, Tsit5

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Tsit5()

t0 = 0
dt0 = 0.05
t1 = 1
y0 = jnp.array(1.0)
args = None

tprev = t0
tnext = t0 + dt0
y = y0
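# Initialise the solver state for the first step.
state = solver.init(term, tprev, tnext, y0, args)

while tprev < t1:
    # step returns (y, error estimate, dense info, solver state, result);
    # only the new value and the updated state are needed here.
    y, _, _, state, _ = solver.step(term, tprev, tnext, y, args, state, made_jump=False)
    print(f"At time {tnext} obtained value {y}")
    tprev = tnext
    # Clip the final step so that we land exactly on t1.
    tnext = min(tprev + dt0, t1)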