Nonlinear heat PDE
Diffrax can also be used to solve some PDEs.
(Specifically, the scope of Diffrax is "any numerical method which iterates over timesteps". This means that e.g. semidiscretised evolution equations are in-scope, but e.g. finite volume methods for elliptic equations are out-of-scope.)
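In this example, we solve the nonlinear heat equation
\[\frac{\partial y}{\partial t}(t, x) = (1 - y(t, x)) \Delta y(t, x) \qquad\text{in}\qquad t \in [0, 1],\ x \in [-1, 1]\]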
subject to the initial condition
\[y(0, x) = x^2,\]
and Dirichlet boundary conditions
\[y(t, -1) = 1 \qquad\text{and}\qquad y(t, 1) = 1.\]
We spatially discretise \(x \in [-1, 1]\) into points \(-1 = x_0 < x_1 < \cdots < x_{n-1} = 1\), with equal spacing \(\delta x = x_{i+1} - x_i\). The solution is then discretised into \(y(t, x_i) \approx y_i(t)\), and the Laplacian discretised into \(\Delta y(t,x_i) \approx \frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\delta x^2}\).
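In doing so we reduce to a system of ODEs
\[\frac{\mathrm{d}y_i}{\mathrm{d}t}(t) = (1 - y_i(t)) \frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\delta x^2} \qquad\text{for}\qquad i \in \{1, \ldots, n-2\},\]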
subject to the initial condition
\[y_i(0) = {x_i}^2,\]
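for which the Dirichlet boundary conditions become
\[\frac{\mathrm{d}y_0}{\mathrm{d}t}(t) = 0 \qquad\text{and}\qquad \frac{\mathrm{d}y_{n-1}}{\mathrm{d}t}(t) = 0.\]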
This example is available as a Jupyter notebook here.
Advanced example
This is an advanced example, as it involves defining a custom solver.
from typing import Callable
import diffrax
import equinox as eqx # https://github.com/patrickkidger/equinox
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float # https://github.com/google/jaxtyping
jax.config.update("jax_enable_x64", True)
# Represents the interval [x0, x_final] discretised into n equally-spaced points.
class SpatialDiscretisation(eqx.Module):
    """Function values sampled on a uniform grid over ``[x0, x_final]``.

    Supports ``+``, ``-`` and ``*`` (including the reflected forms) against
    either another ``SpatialDiscretisation`` on the same interval or any
    operand that the underlying array arithmetic accepts (e.g. a scalar).
    Every operation is applied elementwise to ``vals``.
    """

    # End points of the interval; declared as static fields of the module.
    x0: float = eqx.static_field()
    x_final: float = eqx.static_field()
    # Sampled values, one per grid point.
    vals: Float[Array, "n"]

    @classmethod
    def discretise_fn(cls, x0: float, x_final: float, n: int, fn: Callable):
        """Sample ``fn`` at ``n`` equally-spaced points spanning ``[x0, x_final]``.

        Raises:
            ValueError: if ``n < 2`` (the grid spacing would be undefined).
        """
        if n < 2:
            raise ValueError("Must discretise [x0, x_final] into at least two points")
        vals = jax.vmap(fn)(jnp.linspace(x0, x_final, n))
        return cls(x0, x_final, vals)

    @property
    def δx(self):
        """Spacing between adjacent grid points."""
        return (self.x_final - self.x0) / (len(self.vals) - 1)

    def binop(self, other, fn):
        """Apply the binary function ``fn(self.vals, other)`` and re-wrap the result.

        If ``other`` is a ``SpatialDiscretisation`` it must cover the same
        interval, and its ``vals`` are used as the second operand.

        Raises:
            ValueError: if the two discretisations have different end points.
        """
        if isinstance(other, SpatialDiscretisation):
            if self.x0 != other.x0 or self.x_final != other.x_final:
                raise ValueError("Mismatched spatial discretisations")
            other = other.vals
        return SpatialDiscretisation(self.x0, self.x_final, fn(self.vals, other))

    def __add__(self, other):
        return self.binop(other, lambda x, y: x + y)

    def __mul__(self, other):
        return self.binop(other, lambda x, y: x * y)

    # In the reflected operators `x` is still `self.vals` and `y` is the
    # left-hand operand, hence the swapped order inside the lambdas.
    def __radd__(self, other):
        return self.binop(other, lambda x, y: y + x)

    def __rmul__(self, other):
        return self.binop(other, lambda x, y: y * x)

    def __sub__(self, other):
        return self.binop(other, lambda x, y: x - y)

    def __rsub__(self, other):
        return self.binop(other, lambda x, y: y - x)
def laplacian(y: SpatialDiscretisation) -> SpatialDiscretisation:
    """Central finite-difference Laplacian of ``y``, zeroed at both end points.

    Interior entries are ``(y[i+1] - 2*y[i] + y[i-1]) / δx**2``. The first and
    last entries are set to zero, so that when the result is used as a time
    derivative the boundary values of ``y`` are held fixed.
    """
    # jnp.roll wraps around at the ends; the two wrapped entries only affect
    # indices 0 and -1, which are overwritten below.
    y_next = jnp.roll(y.vals, shift=-1)  # y_next[i] == y.vals[i + 1]
    y_prev = jnp.roll(y.vals, shift=1)  # y_prev[i] == y.vals[i - 1]
    Δy = (y_next - 2 * y.vals + y_prev) / (y.δx**2)
    # Dirichlet boundary condition
    Δy = Δy.at[0].set(0)
    Δy = Δy.at[-1].set(0)
    return SpatialDiscretisation(y.x0, y.x_final, Δy)
First let's try solving this semidiscretisation directly, as a system of ODEs.
# Problem
def vector_field(t, y, args):
    """Right-hand side of the semidiscretised heat equation: ``(1 - y) * Δy``.

    ``t`` and ``args`` are not used; they are accepted so the function has the
    ``(t, y, args)`` signature it is wrapped with in ``diffrax.ODETerm``.
    """
    return (1 - y) * laplacian(y)
term = diffrax.ODETerm(vector_field)


def ic(x):
    """Initial condition y(0, x) = x^2."""
    return x**2


# Spatial discretisation
# The interval is [-1, 1]; with the initial condition x^2 both end points
# start at 1, which is the value the boundary condition holds them at.
x0 = -1
x_final = 1
n = 50
y0 = SpatialDiscretisation.discretise_fn(x0, x_final, n, ic)

# Temporal discretisation
t0 = 0
t_final = 1
δt = 0.0001  # initial step size
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t_final, 50))

# Tolerances
rtol = 1e-10
atol = 1e-10
stepsize_controller = diffrax.PIDController(
    pcoeff=0.3, icoeff=0.4, rtol=rtol, atol=atol, dtmax=0.001
)

solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0,
    t_final,
    δt,
    y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    max_steps=None,
)
# Heat map of the saved solution, with x on the horizontal axis and t on the
# vertical axis.
plt.figure(figsize=(5, 5))
plt.imshow(
    sol.ys.vals,
    origin="lower",
    extent=(x0, x_final, t0, t_final),
    # Ratio of the two axis ranges, so the image is drawn square.
    aspect=(x_final - x0) / (t_final - t0),
    cmap="inferno",
)
plt.xlabel("x")
plt.ylabel("t", rotation=0)
plt.clim(0, 1)
plt.colorbar()
plt.show()
That worked!
However, for more complicated PDEs we may wish to define a custom solver. So as an example, here's how to solve the same PDE using the famous Crank–Nicolson scheme.
(See the page on abstract solvers for more details about how to define a custom solver.)
class CrankNicolson(diffrax.AbstractSolver):
    """Crank–Nicolson solver whose implicit step is solved by fixed-point iteration.

    Each step solves ``y1 = y0 + 0.5 * δt * (f(t0, y0) + f(t1, y1))`` for
    ``y1``, starting from the explicit Euler prediction. Iteration stops once
    every component of the update satisfies
    ``|new_y1 - y1| <= atol + rtol * max(|y1|, |new_y1|)``.

    The state ``y`` is expected to be a ``SpatialDiscretisation`` (its ``.vals``
    attribute is read when testing for convergence).
    """

    # Tolerances for the fixed-point convergence test.
    rtol: float
    atol: float

    term_structure = diffrax.ODETerm
    interpolation_cls = diffrax.LocalLinearInterpolation

    def order(self, terms):
        """Order of the method."""
        return 2

    def init(self, terms, t0, t1, y0, args):
        """This solver carries no state between steps."""
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        """Advance from ``t0`` to ``t1``.

        Returns ``(y1, y_error, dense_info, solver_state, result)``, where
        ``y_error`` is the difference between the Crank–Nicolson solution and
        the explicit Euler prediction.
        """
        del solver_state, made_jump

        δt = t1 - t0
        f0 = terms.vf(t0, y0, args)

        def keep_iterating(val):
            _, not_converged = val
            return not_converged

        def fixed_point_iteration(val):
            y1, _ = val
            new_y1 = y0 + 0.5 * δt * (f0 + terms.vf(t1, y1, args))
            diff = jnp.abs((new_y1 - y1).vals)
            max_y1 = jnp.maximum(jnp.abs(y1.vals), jnp.abs(new_y1.vals))
            scale = self.atol + self.rtol * max_y1
            not_converged = jnp.any(diff > scale)
            return new_y1, not_converged

        euler_y1 = y0 + δt * f0
        # The flag must start as True: lax.while_loop tests the condition
        # before the first iteration, so starting from False would skip the
        # fixed-point iteration entirely and return the Euler prediction.
        y1, _ = lax.while_loop(keep_iterating, fixed_point_iteration, (euler_y1, True))

        y_error = y1 - euler_y1
        dense_info = dict(y0=y0, y1=y1)

        solver_state = None
        result = diffrax.RESULTS.successful
        return y1, y_error, dense_info, solver_state, result

    def func(self, terms, t0, y0, args):
        """Evaluate the vector field at ``(t0, y0)``."""
        return terms.vf(t0, y0, args)
# Solve the same problem again, this time with the custom solver. Everything
# except the solver is reused from the previous solve.
solver = CrankNicolson(rtol=rtol, atol=atol)
sol = diffrax.diffeqsolve(
    terms=term,
    solver=solver,
    t0=t0,
    t1=t_final,
    dt0=δt,
    y0=y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    max_steps=None,
)
# Heat map of the saved solution, with x on the horizontal axis and t on the
# vertical axis.
plt.figure(figsize=(5, 5))
plt.imshow(
    sol.ys.vals,
    origin="lower",
    extent=(x0, x_final, t0, t_final),
    # Ratio of the two axis ranges, so the image is drawn square.
    aspect=(x_final - x0) / (t_final - t0),
    cmap="inferno",
)
plt.xlabel("x")
plt.ylabel("t", rotation=0)
plt.clim(0, 1)
plt.colorbar()
plt.show()
Some final notes.

We wrote down the general Crank–Nicolson method, which uses a fixed point iteration to solve the implicit problem. If you know something about the structure of your problem (e.g. that it is linear) then it is often possible to use more specialised solvers, which run faster. (E.g. linear solvers.)

To keep this example brief, we didn't worry about doing a von Neumann stability analysis.