Interactively step through a solve
Sometimes you might want to perform a differential equation solve just one step at a time (or a few steps at a time), and perhaps do some other computations in between. A common example is solving a differential equation in real time while continually producing some output.
One option is to repeatedly call diffrax.diffeqsolve. However, if that seems inelegant or inefficient to you, you can instead use the solvers (and step size controllers, etc.) directly yourself.
In the following example, we solve an ODE using diffrax.Tsit5 and print out the result as we go.
Note
See the Abstract solvers page for a reference on the solver methods (init, step) used here.
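import jax.numpy as jnp
from diffrax import ODETerm, Tsit5

# dy/dt = -y, integrated from t0 to t1 with fixed steps of size dt0.
vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Tsit5()
t0 = 0
dt0 = 0.05
t1 = 1
y0 = jnp.array(1.0)
args = None

tprev = t0
tnext = t0 + dt0
y = y0
# Initialise the solver's internal state before taking the first step.
state = solver.init(term, tprev, tnext, y0, args)
while tprev < t1:
    # step returns (y1, y_error, dense_info, solver_state, result); only the
    # new value and the updated solver state are needed here.
    y, _, _, state, _ = solver.step(term, tprev, tnext, y, args, state, made_jump=False)
    print(f"At time {tnext} obtained value {y}")
    tprev = tnext
    # Clip the final step so that the solve ends exactly at t1.
    tnext = min(tprev + dt0, t1)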