Skip to content

Nonlinear heat PDE¤

Diffrax can also be used to solve some PDEs.

(Specifically, the scope of Diffrax is "any numerical method which iterates over timesteps". This means that e.g. semidiscretised evolution equations are in-scope, but e.g. finite volume methods for elliptic equations are out-of-scope.)


In this example, we solve the nonlinear heat equation

\[ \frac{\partial y}{\partial t}(t, x) = (1 - y(t, x)) \Delta y(t, x) \qquad\text{in}\qquad t \in [0, 40], x \in [-1, 1]\]

subject to the initial condition

\[ y(0, x) = x^2, \]

and Dirichlet boundary conditions

\[ y(t, -1) = 1,\qquad y(t, 1) = 1. \]

We spatially discretise \(x \in [-1, 1]\) into points \(-1 = x_0 < x_1 < \cdots < x_{n-1} = 1\), with equal spacing \(\delta x = x_{i+1} - x_i\). The solution is then discretised into \(y(t, x_i) \approx y_i(t)\), and the Laplacian discretised into \(\Delta y(t,x_i) \approx \frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\delta x^2}\).

In doing so we reduce to a system of ODEs

\[ \frac{\mathrm{d}y_i}{\mathrm{d}t}(t) = (1 - y_i(t)) \frac{y_{i+1}(t) - 2y_{i}(t) + y_{i-1}(t)}{\delta x^2} \qquad\text{for}\qquad i \in \{1, ..., n-2\},\]

subject to the initial condition

\[ y_i(0) = {x_i}^2, \]

for which the Dirichlet boundary conditions become

\[ \frac{\mathrm{d}y_0}{\mathrm{d}t}(t) = 0,\qquad \frac{\mathrm{d}y_{n-1}}{\mathrm{d}t}(t) = 0. \]

This example is available as a Jupyter notebook here.

Advanced example

This is an advanced example, as it involves defining a custom solver.

from typing import Callable

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping


jax.config.update("jax_enable_x64", True)
# Represents the interval [x0, x_final] discretised into n equally-spaced points.
class SpatialDiscretisation(eqx.Module):
    x0: float = eqx.field(static=True)
    x_final: float = eqx.field(static=True)
    vals: Float[Array, "n"]

    @classmethod
    def discretise_fn(cls, x0: float, x_final: float, n: int, fn: Callable):
        if n < 2:
            raise ValueError("Must discretise [x0, x_final] into at least two points")
        vals = jax.vmap(fn)(jnp.linspace(x0, x_final, n))
        return cls(x0, x_final, vals)

    @property
    def δx(self):
        return (self.x_final - self.x0) / (len(self.vals) - 1)

    def binop(self, other, fn):
        if isinstance(other, SpatialDiscretisation):
            if self.x0 != other.x0 or self.x_final != other.x_final:
                raise ValueError("Mismatched spatial discretisations")
            other = other.vals
        return SpatialDiscretisation(self.x0, self.x_final, fn(self.vals, other))

    def __add__(self, other):
        return self.binop(other, lambda x, y: x + y)

    def __mul__(self, other):
        return self.binop(other, lambda x, y: x * y)

    def __radd__(self, other):
        return self.binop(other, lambda x, y: y + x)

    def __rmul__(self, other):
        return self.binop(other, lambda x, y: y * x)

    def __sub__(self, other):
        return self.binop(other, lambda x, y: x - y)

    def __rsub__(self, other):
        return self.binop(other, lambda x, y: y - x)


def laplacian(y: SpatialDiscretisation) -> SpatialDiscretisation:
    y_next = jnp.roll(y.vals, shift=1)
    y_prev = jnp.roll(y.vals, shift=-1)
    Δy = (y_next - 2 * y.vals + y_prev) / (y.δx**2)
    # Dirichlet boundary condition
    Δy = Δy.at[0].set(0)
    Δy = Δy.at[-1].set(0)
    return SpatialDiscretisation(y.x0, y.x_final, Δy)

First let's try solving this semidiscretisation directly, as a system of ODEs.

# Problem
def vector_field(t, y, args):
    return (1 - y) * laplacian(y)


term = diffrax.ODETerm(vector_field)
ic = lambda x: x**2

# Spatial discretisation
x0 = -1
x_final = 1
n = 50
y0 = SpatialDiscretisation.discretise_fn(x0, x_final, n, ic)

# Temporal discretisation
t0 = 0
t_final = 1
δt = 0.0001
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t_final, 50))

# Tolerances
rtol = 1e-10
atol = 1e-10
stepsize_controller = diffrax.PIDController(
    pcoeff=0.3, icoeff=0.4, rtol=rtol, atol=atol, dtmax=0.001
)
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0,
    t_final,
    δt,
    y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    max_steps=None,
)
plt.figure(figsize=(5, 5))
plt.imshow(
    sol.ys.vals,
    origin="lower",
    extent=(x0, x_final, t0, t_final),
    aspect=(x_final - x0) / (t_final - t0),
    cmap="inferno",
)
plt.xlabel("x")
plt.ylabel("t", rotation=0)
plt.clim(0, 1)
plt.colorbar()
plt.show()

That worked!

However, for more complicated PDEs then we may wish to define a custom solver. So as an example, here's how to solve the same PDE using the famous Crank–Nicolson scheme.

(See the page on abstract solvers for more details about how to define a custom solver.)

class CrankNicolson(diffrax.AbstractSolver):
    rtol: float
    atol: float

    term_structure = diffrax.ODETerm
    interpolation_cls = diffrax.LocalLinearInterpolation

    def order(self, terms):
        return 2

    def init(self, terms, t0, t1, y0, args):
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        del solver_state, made_jump
        δt = t1 - t0
        f0 = terms.vf(t0, y0, args)

        def keep_iterating(val):
            _, not_converged = val
            return not_converged

        def fixed_point_iteration(val):
            y1, _ = val
            new_y1 = y0 + 0.5 * δt * (f0 + terms.vf(t1, y1, args))
            diff = jnp.abs((new_y1 - y1).vals)
            max_y1 = jnp.maximum(jnp.abs(y1.vals), jnp.abs(new_y1.vals))
            scale = self.atol + self.rtol * max_y1
            not_converged = jnp.any(diff > scale)
            return new_y1, not_converged

        euler_y1 = y0 + δt * f0
        y1, _ = lax.while_loop(keep_iterating, fixed_point_iteration, (euler_y1, False))

        y_error = y1 - euler_y1
        dense_info = dict(y0=y0, y1=y1)

        solver_state = None
        result = diffrax.RESULTS.successful
        return y1, y_error, dense_info, solver_state, result

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)
solver = CrankNicolson(rtol=rtol, atol=atol)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0,
    t_final,
    δt,
    y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    max_steps=None,
)
plt.figure(figsize=(5, 5))
plt.imshow(
    sol.ys.vals,
    origin="lower",
    extent=(x0, x_final, t0, t_final),
    aspect=(x_final - x0) / (t_final - t0),
    cmap="inferno",
)
plt.xlabel("x")
plt.ylabel("t", rotation=0)
plt.clim(0, 1)
plt.colorbar()
plt.show()

Some final notes.

  1. We wrote down the general Crank–Nicolson method, which uses a fixed point iteration to solve the implicit problem. If you know something about the structure of your problem (e.g. that it is linear) then it is often possible to more specialised solvers, which run faster. (E.g. linear solvers.)

  2. To keep this example brief, we didn't worry about doing a von Neumann stability analysis.