Solving an ODE with a forcing termยค
This example demonstrates how to incorporate an external forcing term into the solve. This is really simple: just evaluate it as part of the vector field like anything else.
This example is available as a Jupyter notebook here.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
def force(t, args):
m, c = args
return m * t + c
def vector_field(t, y, args):
return -y + force(t, args)
@jax.jit
def solve(y0, args):
term = ODETerm(vector_field)
solver = Tsit5()
t0 = 0
t1 = 10
dt0 = 0.1
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)
return sol
y0 = 1.0
args = (0.1, 0.02)
sol = solve(y0, args)
plt.plot(sol.ts, sol.ys)
plt.xlabel("t")
plt.ylabel("y")
plt.show()
Now let's consider a more complicated example: the forcing term is an interpolation, and what's more we would like to differentiate with respect to the values we are interpolating.
from diffrax import backward_hermite_coefficients, CubicInterpolation
def vector_field2(t, y, interp):
return -y + interp.evaluate(t)
@jax.jit
@jax.grad
def solve(points):
t0 = 0
t1 = 10
ts = jnp.linspace(t0, t1, len(points))
coeffs = backward_hermite_coefficients(ts, points)
interp = CubicInterpolation(ts, coeffs)
term = ODETerm(vector_field2)
solver = Tsit5()
dt0 = 0.1
y0 = 1.0
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=interp)
(y1,) = sol.ys
return y1
points = jnp.array([3.0, 0.5, -0.8, 1.8])
grads = solve(points)
In this example, we computed the interpolation in advance (not repeatedly on each step!), and then just evaluated it inside the vector field.