Adjoints¤

There are multiple ways to backpropagate through a differential equation (to compute the gradient of the solution with respect to its initial condition and any parameters). For example, suppose we are given an ODE

\(\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))\)

on \([t_0, t_1]\), with initial condition \(y(t_0) = y_0\). So \(y(t)\) is the (unknown) exact solution, to which we will compute some numerical approximation \(y_N \approx y(t_1)\).

We may directly apply autodifferentiation to calculate \(\frac{\mathrm{d}y_N}{\mathrm{d}y_0}\), by backpropagating through the internals of the solver. This is known as "discretise then optimise", is the default in Diffrax, and corresponds to diffrax.RecursiveCheckpointAdjoint below.

Alternatively we may compute \(\frac{\mathrm{d}y(t_1)}{\mathrm{d}y_0}\) analytically. In doing so we obtain a backwards-in-time ODE that we must numerically solve to obtain the desired gradients. This is known as "optimise then discretise", and corresponds to diffrax.BacksolveAdjoint below.
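The difference between the two approaches can be seen on a toy problem. The following is a minimal sketch in plain Python (not Diffrax), assuming the ODE \(\frac{\mathrm{d}y}{\mathrm{d}t} = -y^2\) and a forward Euler solver; every function name here is illustrative. The first gradient is the exact derivative of the numerical solution, the second comes from numerically solving the adjoint ODE \(\frac{\mathrm{d}\lambda}{\mathrm{d}t} = -\lambda \frac{\partial f}{\partial y}\) backwards in time.

```python
def f(y):
    """Vector field of the example ODE dy/dt = -y**2."""
    return -y * y


def df_dy(y):
    """Derivative of the vector field with respect to y."""
    return -2.0 * y


def euler_solve(y0, h, n):
    """Forward Euler with n steps of size h; returns every state visited."""
    ys = [y0]
    for _ in range(n):
        ys.append(ys[-1] + h * f(ys[-1]))
    return ys


def grad_discretise_then_optimise(y0, h, n):
    """Backpropagate through each Euler step: exact gradient of y_N."""
    ys = euler_solve(y0, h, n)
    grad = 1.0  # d y_N / d y_N
    for y_k in reversed(ys[:-1]):
        grad *= 1.0 + h * df_dy(y_k)  # chain rule factor d y_{k+1} / d y_k
    return grad


def grad_optimise_then_discretise(y0, h, n):
    """Solve the adjoint ODE backwards in time, again with Euler."""
    y = euler_solve(y0, h, n)[-1]  # only the final state is stored
    lam = 1.0  # lam(t1) = d y(t1) / d y(t1)
    for _ in range(n):
        # Step y and lam together from time t to t - h, using values at t.
        y, lam = y - h * f(y), lam + h * lam * df_dy(y)
    return lam


y0, t1, n = 1.0, 1.0, 100
h = t1 / n
dto = grad_discretise_then_optimise(y0, h, n)
otd = grad_optimise_then_discretise(y0, h, n)
exact = 1.0 / (1.0 + y0 * t1) ** 2  # from y(t) = y0 / (1 + y0 * t)
print(dto, otd, exact)
```

Both values approximate the true gradient, but only the first is the exact derivative of the numerical solution \(y_N\). The second reconstructs \(y\) backwards in time and so picks up an additional discretisation error.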

diffrax.AbstractAdjoint ¤

Abstract base class for all adjoint methods.

loop(self, *, args, terms, solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, throw, init_state, passed_solver_state, passed_controller_state, progress_meter) -> Any abstractmethod ¤

Runs the main solve loop. Subclasses can override this to provide custom backpropagation behaviour; see for example the implementation of diffrax.BacksolveAdjoint.

diffrax.RecursiveCheckpointAdjoint (AbstractAdjoint) ¤

Enables support for backpropagating through diffrax.diffeqsolve by differentiating the numerical solution directly. This is sometimes known as "discretise-then-optimise", or described as "backpropagation through the solver".

Uses a binomial checkpointing scheme to keep memory usage low.

For most problems this is the preferred technique for backpropagating through a differential equation, and as such it is the default for diffrax.diffeqsolve.

Info

Note that this cannot be forward-mode autodifferentiated (e.g. using jax.jvp). Use diffrax.ForwardMode if you need forward-mode autodifferentiation, or diffrax.DirectAdjoint if you need both forward-mode and reverse-mode autodifferentiation.

References

Selecting the steps at which to save checkpoints (and, when saving a new one, which old checkpoint to evict) is important for minimising the amount of recomputation performed.

The implementation here performs "online checkpointing", as the number of steps is not known in advance. This was developed in:

@article{stumm2010new,
    author = {Stumm, Philipp and Walther, Andrea},
    title = {New Algorithms for Optimal Online Checkpointing},
    journal = {SIAM Journal on Scientific Computing},
    volume = {32},
    number = {2},
    pages = {836--854},
    year = {2010},
    doi = {10.1137/080742439},
}

@article{wang2009minimal,
    author = {Wang, Qiqi and Moin, Parviz and Iaccarino, Gianluca},
    title = {Minimal Repetition Dynamic Checkpointing Algorithm for Unsteady
             Adjoint Calculation},
    journal = {SIAM Journal on Scientific Computing},
    volume = {31},
    number = {4},
    pages = {2549--2567},
    year = {2009},
    doi = {10.1137/080727890},
}

For reference, the classical "offline checkpointing" (also known as "treeverse", "recursive binary checkpointing", "revolve" etc.) was developed in:

@article{griewank1992achieving,
    author = {Griewank, Andreas},
    title = {Achieving logarithmic growth of temporal and spatial complexity in
             reverse automatic differentiation},
    journal = {Optimization Methods and Software},
    volume = {1},
    number = {1},
    pages = {35--54},
    year  = {1992},
    publisher = {Taylor & Francis},
    doi = {10.1080/10556789208805505},
}

@article{griewank2000revolve,
    author = {Griewank, Andreas and Walther, Andrea},
    title = {Algorithm 799: Revolve: An Implementation of Checkpointing for the
             Reverse or Adjoint Mode of Computational Differentiation},
    year = {2000},
    publisher = {Association for Computing Machinery},
    volume = {26},
    number = {1},
    doi = {10.1145/347837.347846},
    journal = {ACM Trans. Math. Softw.},
    pages = {19--45},
}

__init__(self, checkpoints: Optional[int] = None) ¤

Arguments:

  • checkpoints: the number of checkpoints to save. The amount of memory used by the differential equation solve will be roughly equal to the number of checkpoints multiplied by the size of y0. You can speed up backpropagation by allocating more checkpoints. (So it makes sense to set as many checkpoints as you have memory for.) This value can also be set to None (the default), in which case it will be set to log(max_steps), for which a theoretical result is available guaranteeing that backpropagation will take O(n log n) time in the number of steps n <= max_steps.

You must pass either diffeqsolve(..., max_steps=...) or RecursiveCheckpointAdjoint(checkpoints=...) to be able to backpropagate; otherwise the computation will not be autodifferentiable.

diffrax.ForwardMode (AbstractAdjoint) ¤

Enables support for forward-mode automatic differentiation (like jax.jvp or jax.jacfwd) through diffrax.diffeqsolve. (As such this shouldn't really be called an 'adjoint' method -- which is a word that refers to any kind of reverse-mode autodifferentiation. Ah well.)

This is useful when we have many more outputs than inputs to a function - for instance during parameter inference for ODE models with least-squares solvers such as optimistix.LevenbergMarquardt, that operate on the residuals.

diffrax.ImplicitAdjoint (AbstractAdjoint) ¤

Backpropagate via the implicit function theorem.

This is used when solving towards a steady state, typically using diffrax.Event where the condition function is obtained by calling diffrax.steady_state_event. In this case, the output of the solver is \(y(θ)\) for which \(f(t, y(θ), θ) = 0\). (Where \(θ\) corresponds to all parameters found through terms and args, but not y0.) Then we can skip backpropagating through the solver and instead directly compute \(\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}\) via the implicit function theorem.

Observe that this involves solving a linear system with matrix given by the Jacobian df/dy.
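The formula can be checked on a scalar example. The following is a minimal sketch in plain Python (not Diffrax), assuming the vector field \(f(y, θ) = θ - y^3\), whose steady state is \(y = θ^{1/3}\). In the scalar case the linear solve reduces to a division. All names are illustrative.

```python
def f(y, theta):
    """Vector field of dy/dt = theta - y**3; its steady state is theta**(1/3)."""
    return theta - y ** 3


def df_dy(y, theta):
    """Jacobian of f with respect to y (a scalar here)."""
    return -3.0 * y ** 2


def df_dtheta(y, theta):
    """Derivative of f with respect to the parameter theta."""
    return 1.0


def steady_state(theta, h=0.01, tol=1e-12, max_steps=1_000_000):
    """Integrate with forward Euler until the vector field is numerically zero."""
    y = 1.0
    for _ in range(max_steps):
        if abs(f(y, theta)) < tol:
            break
        y += h * f(y, theta)
    return y


theta = 8.0
y_star = steady_state(theta)

# Implicit function theorem: dy/dtheta = -(df/dy)^{-1} df/dtheta,
# evaluated at the steady state, with no backpropagation through the loop.
dy_dtheta = -df_dtheta(y_star, theta) / df_dy(y_star, theta)
print(y_star, dy_dtheta)
```

The result agrees with differentiating \(θ^{1/3}\) directly, which gives \(\frac{1}{3}θ^{-2/3} = \frac{1}{12}\) at \(θ = 8\).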

__init__(self, linear_solver: AbstractLinearSolver = AutoLinearSolver(well_posed=None), tags: Union[object, Iterable[object]] = <factory>) ¤

Arguments:

  • linear_solver: A Lineax solver for solving the linear system.
  • tags: Any Lineax tags describing the Jacobian matrix df/dy.

diffrax.BacksolveAdjoint (AbstractAdjoint) ¤

Backpropagate through diffrax.diffeqsolve by solving the continuous adjoint equations backwards-in-time. This is also sometimes known as "optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint method".

Warning

This method is not recommended! It was popularised by this paper, and for this reason it is sometimes erroneously believed to be a better method for backpropagation than other choices available.

In practice, whilst BacksolveAdjoint does have very low memory usage, its computed gradients are only approximate. Since the checkpointing of diffrax.RecursiveCheckpointAdjoint also gives low memory usage, that method is essentially always preferred.

This will compute gradients with respect to the terms, y0 and args arguments passed to diffrax.diffeqsolve. If you attempt to compute gradients with respect to anything else (for example t0, or arguments passed via closure), then a CustomVJPException will be raised by JAX. See also this FAQ entry.

Info

Using this method prevents computing forward-mode autoderivatives of diffrax.diffeqsolve. (That is to say, jax.jvp will not work.)

__init__(self, **kwargs) ¤

Arguments:

  • **kwargs: The arguments for the diffrax.diffeqsolve operations that are called on the backward pass. For example use
    BacksolveAdjoint(solver=Dopri5())
    
    to specify a particular solver to use on the backward pass.

diffrax.DirectAdjoint (AbstractAdjoint) ¤

A variant of diffrax.RecursiveCheckpointAdjoint that is also able to support forward-mode autodifferentiation, whilst being less computationally efficient. (Under-the-hood it is using several nested layers of jax.lax.scans and jax.checkpoints, so that the cost of the solve increases with max_steps, even if you don't need that many steps to perform the solve in practice.)

Warning

This method is not recommended! In practice you should almost always use either diffrax.RecursiveCheckpointAdjoint or diffrax.ForwardMode, depending on whether you need reverse-mode or forward-mode autodifferentiation. As this method is far less computationally efficient, it is only useful if you really need to support both kinds of autodifferentiation.


diffrax.adjoint_rms_seminorm(x: tuple[PyTree, PyTree, PyTree, PyTree]) -> Union[float, int] ¤

Defines an adjoint seminorm. This can frequently be used to increase the efficiency of backpropagation via diffrax.BacksolveAdjoint, as follows:

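import diffrax

# PIDController needs explicit rtol and atol; these values are only illustrative.
adjoint_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6, norm=diffrax.adjoint_rms_seminorm)
adjoint = diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller)
diffrax.diffeqsolve(..., adjoint=adjoint)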

Note that this means that any stepsize_controller specified for the forward pass will not be automatically used for the backward pass (as adjoint_controller overrides it), so you should specify any custom rtol, atol etc. for the backward pass as well.

Reference

@article{kidger2021hey,
    author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry},
    title={``{H}ey, that's not an {ODE}'': {F}aster {ODE} {A}djoints via
           {S}eminorms},
    year={2021},
    journal={International Conference on Machine Learning}
}