This example demonstrates how to incorporate an external forcing term into the solve. This is really simple: just evaluate it as part of the vector field like anything else.
This example is available as a Jupyter notebook here.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5


def force(t, args):
    """Return the external forcing at time ``t``.

    ``args`` is unpacked as a pair ``(m, c)``; the forcing is the affine
    function ``m * t + c``.
    """
    slope, offset = args
    return slope * t + offset


def vector_field(t, y, args):
    """Right-hand side of the ODE: ``dy/dt = -y + force(t, args)``.

    The forcing is evaluated inline, exactly like any other part of the
    vector field.
    """
    forcing = force(t, args)
    return -y + forcing


@jax.jit
def solve(y0, args):
    """Solve the forced ODE from ``t = 0`` to ``t = 10`` and return the solution.

    ``y0`` is the initial condition and ``args`` is forwarded unchanged to
    ``vector_field`` (and from there to ``force``). The solution is saved at
    1000 evenly spaced times over the integration interval.
    """
    t_start = 0
    t_end = 10
    step = 0.1  # passed to diffeqsolve as dt0
    save_times = jnp.linspace(t_start, t_end, 1000)
    return diffeqsolve(
        ODETerm(vector_field),
        Tsit5(),
        t_start,
        t_end,
        step,
        y0,
        args=args,
        saveat=SaveAt(ts=save_times),
    )


# Initial condition and the (m, c) coefficients of the forcing term.
y0 = 1.0
args = (0.1, 0.02)
sol = solve(y0, args)
# Plot the saved solution values against the times they were saved at.
plt.plot(sol.ts, sol.ys)
plt.xlabel("t")
plt.ylabel("y")
plt.show()
Now let's consider a more complicated example: the forcing term is an interpolation and, what's more, we would like to differentiate with respect to the values being interpolated.
from diffrax import backward_hermite_coefficients, CubicInterpolation


def vector_field2(t, y, interp):
    """Right-hand side ``dy/dt = -y + interp.evaluate(t)``.

    ``interp`` arrives through the solver's ``args`` and supplies the forcing
    term by being evaluated at the current time.
    """
    forcing = interp.evaluate(t)
    return -y + forcing


@jax.jit
@jax.grad
def solve(points):
    """Gradient, with respect to ``points``, of the value returned by the solve.

    ``points`` are the values being interpolated; they are placed at
    ``len(points)`` evenly spaced times between 0 and 10. Because of the
    ``jax.grad`` decorator, calling this function returns the gradient of the
    inner return value rather than the value itself.
    """
    t_start = 0
    t_end = 10
    knots = jnp.linspace(t_start, t_end, len(points))

    # Build the interpolation once, up front, rather than inside the vector
    # field where it would be recomputed on every evaluation.
    hermite_coeffs = backward_hermite_coefficients(knots, points)
    forcing_interp = CubicInterpolation(knots, hermite_coeffs)

    step = 0.1  # passed to diffeqsolve as dt0
    initial_value = 1.0
    sol = diffeqsolve(
        ODETerm(vector_field2),
        Tsit5(),
        t_start,
        t_end,
        step,
        initial_value,
        args=forcing_interp,
    )
    # sol.ys is unpacked as a single-element sequence; with no saveat given
    # this is presumably the value at the final time (confirm against the
    # diffrax SaveAt defaults).
    (y1,) = sol.ys
    return y1


points = jnp.array([3.0, 0.5, -0.8, 1.8])
grads = solve(points)
In this example, we computed the interpolation in advance (not repeatedly on each step!), and then just evaluated it inside the vector field.