Adjoints
There are multiple ways to backpropagate through a differential equation (to compute the gradient of the solution with respect to its initial condition and any parameters).
Info
Why are there multiple ways of backpropagating through a differential equation? Suppose we are given an ODE
\(\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))\)
on \([t_0, t_1]\), with initial condition \(y(t_0) = y_0\). So \(y(t)\) is the (unknown) exact solution, to which we will compute some numerical approximation \(y_N \approx y(t_1)\).
We may directly apply autodifferentiation to calculate \(\frac{\mathrm{d}y_N}{\mathrm{d}y_0}\), by backpropagating through the internals of the solver. This is known as "discretise then optimise", is the default in Diffrax, and corresponds to diffrax.RecursiveCheckpointAdjoint below.
Alternatively we may compute \(\frac{\mathrm{d}y(t_1)}{\mathrm{d}y_0}\) analytically. In doing so we obtain a backwards-in-time ODE that we must numerically solve to obtain the desired gradients. This is known as "optimise then discretise", and corresponds to diffrax.BacksolveAdjoint below.
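As a concrete sketch of "discretise then optimise" (this example is not taken from the Diffrax documentation; the vector field, solver choice and step size are arbitrary assumptions), reverse-mode autodifferentiation is applied straight through diffrax.diffeqsolve:

import jax
import jax.numpy as jnp
import diffrax

# Toy linear ODE dy/dt = -theta * y, with exact solution y0 * exp(-theta * t).
def vector_field(t, y, args):
    theta = args
    return -theta * y

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()

def solve(y0, theta):
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=y0, args=theta)
    return sol.ys[-1]  # numerical approximation y_N ≈ y(t_1)

# "Discretise then optimise": backpropagate through the solver's internals.
dyN_dy0 = jax.grad(solve, argnums=0)(1.0, 2.0)
dyN_dtheta = jax.grad(solve, argnums=1)(1.0, 2.0)

The later sketches on this page reuse the term and solver defined here.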
diffrax.AbstractAdjoint
Abstract base class for all adjoint methods.
loop(self, *, args, terms, solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, throw, init_state, passed_solver_state, passed_controller_state)
abstractmethod
Runs the main solve loop. Subclasses can override this to provide custom backpropagation behaviour; see for example the implementation of diffrax.BacksolveAdjoint.
diffrax.RecursiveCheckpointAdjoint (AbstractAdjoint)
Backpropagate through diffrax.diffeqsolve by differentiating the numerical solution directly. This is sometimes known as "discretise-then-optimise", or described as "backpropagation through the solver".
For most problems this is the preferred technique for backpropagating through a differential equation.
In addition, a binomial checkpointing scheme is used so that memory usage is low. (This checkpointing can increase compile time a bit, though.)
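For example (a sketch reusing the term and solver from the first example above): RecursiveCheckpointAdjoint is the default, so it only needs to be passed explicitly in order to adjust its behaviour, e.g. via the checkpoints argument assumed here.

# RecursiveCheckpointAdjoint is the default adjoint; passing it explicitly is only
# needed to tweak it, e.g. the assumed `checkpoints` argument below.
adjoint = diffrax.RecursiveCheckpointAdjoint(checkpoints=16)
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=1.0,
                          args=2.0, adjoint=adjoint)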
diffrax.NoAdjoint (AbstractAdjoint)
Disable backpropagation through diffrax.diffeqsolve.
Forward-mode autodifferentiation (jax.jvp) will continue to work as normal.
If you do not need to differentiate the results of diffrax.diffeqsolve then this may sometimes improve the speed at which the differential equation is solved.
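A sketch of where this might be used (again reusing the term and solver defined above): forward solves and jax.jvp still work, but jax.grad would not.

# No reverse-mode gradients are needed here, so NoAdjoint may make the solve cheaper.
def solve_no_adjoint(y0):
    sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=y0,
                              args=2.0, adjoint=diffrax.NoAdjoint())
    return sol.ys[-1]

value = solve_no_adjoint(1.0)
# Forward-mode autodifferentiation continues to work as normal:
primal, tangent = jax.jvp(solve_no_adjoint, (1.0,), (1.0,))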
diffrax.ImplicitAdjoint (AbstractAdjoint)
Backpropagate via the implicit function theorem.
This is used when solving towards a steady state, typically using diffrax.SteadyStateEvent. In this case, the output of the solver is \(y(θ)\), for which \(f(t, y(θ), θ) = 0\). (Here \(θ\) corresponds to all parameters found through terms and args, but not y0.) Then we can skip backpropagating through the solver and instead directly compute
\(\frac{\mathrm{d}y}{\mathrm{d}θ} = - \left(\frac{\mathrm{d}f}{\mathrm{d}y}\right)^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}\)
via the implicit function theorem.
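A sketch of a steady-state solve differentiated this way. The tolerances are arbitrary, and the exact argument names (in particular passing the event via discrete_terminating_event, and max_steps=None with t1=jnp.inf) are assumptions that may differ between Diffrax versions.

# Solve dy/dt = theta - y to its steady state y = theta, then differentiate the
# steady state with respect to theta via the implicit function theorem.
def steady_field(t, y, args):
    theta = args
    return theta - y

def steady_state(theta):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(steady_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=jnp.inf,
        dt0=None,
        y0=0.0,
        args=theta,
        max_steps=None,
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        discrete_terminating_event=diffrax.SteadyStateEvent(),
        adjoint=diffrax.ImplicitAdjoint(),
    )
    return sol.ys[-1]

grad_wrt_theta = jax.grad(steady_state)(2.0)  # ≈ 1, since the steady state is y = theta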
diffrax.BacksolveAdjoint (AbstractAdjoint)
Backpropagate through diffrax.diffeqsolve by solving the continuous adjoint equations backwards-in-time. This is also sometimes known as "optimise-then-discretise", the "continuous adjoint method", or simply the "adjoint method".
This method implies very low memory usage, but the computed gradients will only be approximate. As such, other methods are generally preferred unless memory usage is a concern.
This will compute gradients with respect to the terms, y0 and args arguments passed to diffrax.diffeqsolve. If you attempt to compute gradients with respect to anything else (for example t0, or arguments passed via closure), then a CustomVJPException will be raised. See also this FAQ entry.
Note
This was popularised by this paper. For this reason it is sometimes erroneously believed to be a better method for backpropagation than the other choices available.
Warning
Using this method prevents computing forward-mode autoderivatives of diffrax.diffeqsolve. (That is to say, jax.jvp will not work.)
__init__(self, **kwargs)
Arguments:
**kwargs: The arguments for the diffrax.diffeqsolve operations that are called on the backward pass. For example, use BacksolveAdjoint(solver=Dopri5()) to specify a particular solver to use on the backward pass.
diffrax.adjoint_rms_seminorm(x: Tuple[PyTree, PyTree, PyTree, PyTree]) -> Scalar
Defines an adjoint seminorm. This can frequently be used to increase the efficiency of backpropagation via diffrax.BacksolveAdjoint, as follows:
adjoint_controller = diffrax.PIDController(norm=diffrax.adjoint_rms_seminorm)
adjoint = diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller)
diffrax.diffeqsolve(..., adjoint=adjoint)
Note that this means that any stepsize_controller specified for the forward pass will not be automatically used for the backward pass (as adjoint_controller overrides it), so you should specify any custom rtol, atol, etc. for the backward pass as well.
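For instance, a sketch with assumed tolerances (and reusing the earlier toy setup), making the backward-pass tolerances explicit alongside the forward-pass controller:

forward_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6)
adjoint_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6,
                                           norm=diffrax.adjoint_rms_seminorm)
adjoint = diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller)
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=1.0, args=2.0,
                          stepsize_controller=forward_controller, adjoint=adjoint)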
Reference
@article{kidger2021hey,
author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry},
title={``{H}ey, that's not an {ODE}'': {F}aster {ODE} {A}djoints via
{S}eminorms},
year={2021},
journal={International Conference on Machine Learning}
}