# Adjoints

There are multiple ways to backpropagate through a differential equation (to compute the gradient of the solution with respect to its initial condition and any parameters).

Info

Why are there multiple ways of backpropagating through a differential equation? Suppose we are given an ODE

\(\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t))\)

on \([t_0, t_1]\), with initial condition \(y(t_0) = y_0\). So \(y(t)\) is the (unknown) exact solution, to which we will compute some numerical approximation \(y_N \approx y(t_1)\).

We may directly apply autodifferentiation to calculate \(\frac{\mathrm{d}y_N}{\mathrm{d}y_0}\), by backpropagating through the internals of the solver. This is known as "discretise then optimise", is the default in Diffrax, and corresponds to `diffrax.RecursiveCheckpointAdjoint` below.

Alternatively we may compute \(\frac{\mathrm{d}y(t_1)}{\mathrm{d}y_0}\) analytically. In doing so we obtain a backwards-in-time ODE that we must numerically solve to obtain the desired gradients. This is known as "optimise then discretise", and corresponds to `diffrax.BacksolveAdjoint` below.
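As a sketch of where that backwards-in-time ODE comes from (a standard derivation, restricted here to the gradient with respect to \(y_0\) and assuming \(f\) is differentiable in \(y\)), let \(a(t)\) be the Jacobian of the final state with respect to the state at time \(t\). It satisfies

\[
a(t) = \frac{\partial y(t_1)}{\partial y(t)}, \qquad a(t_1) = I, \qquad \frac{\mathrm{d}a}{\mathrm{d}t} = -a(t)\,\frac{\partial f}{\partial y}(t, y(t)), \qquad \frac{\mathrm{d}y(t_1)}{\mathrm{d}y_0} = a(t_0).
\]

This is solved from \(t_1\) back to \(t_0\). Because \(\frac{\partial f}{\partial y}\) is evaluated along \(y(t)\), the forward trajectory must also be available (or recomputed) on the backward pass. As a check, for \(f(t, y) = -\lambda y\) we get \(a(t) = e^{-\lambda(t_1 - t)}\), so \(a(t_0) = e^{-\lambda(t_1 - t_0)}\), which matches differentiating the exact solution \(y(t_1) = e^{-\lambda(t_1 - t_0)} y_0\).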

#### `diffrax.AbstractAdjoint`

Abstract base class for all adjoint methods.

##### `loop(self, *, args, terms, solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, throw, init_state, passed_solver_state, passed_controller_state) -> Any`

`abstractmethod`

Runs the main solve loop. Subclasses can override this to provide custom
backpropagation behaviour; see for example the implementation of
`diffrax.BacksolveAdjoint`.


Of the following options, `diffrax.RecursiveCheckpointAdjoint` and `diffrax.BacksolveAdjoint` can only be reverse-mode autodifferentiated. `diffrax.ImplicitAdjoint` and `diffrax.DirectAdjoint` support both forward and reverse-mode autodifferentiation.

#### `diffrax.RecursiveCheckpointAdjoint (AbstractAdjoint)`

Backpropagate through `diffrax.diffeqsolve` by differentiating the numerical
solution directly. This is sometimes known as "discretise-then-optimise", or
described as "backpropagation through the solver".

Uses a binomial checkpointing scheme to keep memory usage low.

For most problems this is the preferred technique for backpropagating through a differential equation.

Info

Note that this cannot be forward-mode autodifferentiated (e.g. using `jax.jvp`). Try using `diffrax.DirectAdjoint` if that is something you need.

## References

Selecting the steps at which to save checkpoints (and, when a new checkpoint is saved, which old checkpoint to evict) is important for minimising the amount of recomputation performed.

The implementation here performs "online checkpointing", as the number of steps is not known in advance. This was developed in:

```
@article{stumm2010new,
author = {Stumm, Philipp and Walther, Andrea},
title = {New Algorithms for Optimal Online Checkpointing},
journal = {SIAM Journal on Scientific Computing},
volume = {32},
number = {2},
pages = {836--854},
year = {2010},
doi = {10.1137/080742439},
}
@article{wang2009minimal,
author = {Wang, Qiqi and Moin, Parviz and Iaccarino, Gianluca},
title = {Minimal Repetition Dynamic Checkpointing Algorithm for Unsteady
Adjoint Calculation},
journal = {SIAM Journal on Scientific Computing},
volume = {31},
number = {4},
pages = {2549--2567},
year = {2009},
doi = {10.1137/080727890},
}
```

For reference, the classical "offline checkpointing" (also known as "treeverse", "recursive binary checkpointing", "revolve" etc.) was developed in:

```
@article{griewank1992achieving,
author = {Griewank, Andreas},
title = {Achieving logarithmic growth of temporal and spatial complexity in
reverse automatic differentiation},
journal = {Optimization Methods and Software},
volume = {1},
number = {1},
pages = {35--54},
year = {1992},
publisher = {Taylor & Francis},
doi = {10.1080/10556789208805505},
}
@article{griewank2000revolve,
author = {Griewank, Andreas and Walther, Andrea},
title = {Algorithm 799: Revolve: An Implementation of Checkpointing for the
Reverse or Adjoint Mode of Computational Differentiation},
year = {2000},
publisher = {Association for Computing Machinery},
volume = {26},
number = {1},
doi = {10.1145/347837.347846},
journal = {ACM Trans. Math. Softw.},
pages = {19--45},
}
```

##### `__init__(self, checkpoints: Optional[int] = None)`

**Arguments:**

`checkpoints`: the number of checkpoints to save. The amount of memory used by the differential equation solve will be roughly equal to the number of checkpoints multiplied by the size of `y0`. You can speed up backpropagation by allocating more checkpoints. (So it makes sense to set as many checkpoints as you have memory for.) This value can also be set to `None` (the default), in which case it will be set to `log(max_steps)`, for which a theoretical result is available guaranteeing that backpropagation will take `O(n log n)` time in the number of steps `n <= max_steps`.

You must pass either `diffeqsolve(..., max_steps=...)` or
`RecursiveCheckpointAdjoint(checkpoints=...)` to be able to backpropagate; otherwise
the computation will not be autodifferentiable.

#### `diffrax.BacksolveAdjoint (AbstractAdjoint)`

Backpropagate through `diffrax.diffeqsolve` by solving the continuous
adjoint equations backwards-in-time. This is also sometimes known as
"optimise-then-discretise", the "continuous adjoint method" or simply the "adjoint
method".

This method has very low memory usage, but the computed gradients will only be approximate. As such, other methods are generally preferred unless exceeding memory is a concern.

This will compute gradients with respect to the `terms`, `y0` and `args` arguments
passed to `diffrax.diffeqsolve`. If you attempt to compute gradients with
respect to anything else (for example `t0`, or arguments passed via closure), then
a `CustomVJPException` will be raised. See also
this FAQ entry.

Note

This was popularised by this paper. For this reason it is sometimes erroneously believed to be a better method for backpropagation than the other choices available.

Warning

Using this method prevents computing forward-mode autoderivatives of
`diffrax.diffeqsolve`. (That is to say, `jax.jvp` will not work.)

##### `__init__(self, **kwargs)`

**Arguments:**

`**kwargs`: The arguments for the `diffrax.diffeqsolve` operations that are called on the backward pass. For example use `BacksolveAdjoint(solver=Dopri5())` to specify a particular solver to use on the backward pass.

#### `diffrax.ImplicitAdjoint (AbstractAdjoint)`

Backpropagate via the implicit function theorem.

This is used when solving towards a steady state, typically using
`diffrax.SteadyStateEvent`. In this case, the output of the solver is \(y(θ)\)
for which \(f(t, y(θ), θ) = 0\). (Where \(θ\) corresponds to all parameters found
through `terms` and `args`, but not `y0`.) Then we can skip backpropagating through
the solver and instead directly compute
\(\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}\)
via the implicit function theorem.

Observe that this involves solving a linear system with matrix given by the Jacobian
`df/dy`.
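The implicit-function-theorem formula can be checked numerically in plain JAX, independently of Diffrax. This sketch assumes a hypothetical autonomous linear vector field \(f(y, θ) = θ - A y\) (so the \(t\) argument is dropped), whose steady state is \(y(θ) = A^{-1} θ\); the matrix \(A\) is an arbitrary choice with positive eigenvalues, so the steady state is stable. The Jacobian is obtained with a linear solve rather than an explicit inverse.

```python
import jax
import jax.numpy as jnp

# Hypothetical system matrix; eigenvalues of -A are negative, so y* is a stable steady state.
A = jnp.array([[3.0, 1.0], [0.5, 2.0]])


def f(y, theta):
    # Autonomous linear vector field f(y, theta) = theta - A y.
    return theta - A @ y


def steady_state(theta):
    # Closed-form root of f: y*(theta) = A^{-1} theta.
    return jnp.linalg.solve(A, theta)


theta = jnp.array([1.0, -2.0])
y_star = steady_state(theta)

df_dy = jax.jacobian(f, argnums=0)(y_star, theta)      # equals -A
df_dtheta = jax.jacobian(f, argnums=1)(y_star, theta)  # equals the identity

# Implicit function theorem: dy/dtheta = -(df/dy)^{-1} df/dtheta, via a linear solve.
dy_dtheta_ift = -jnp.linalg.solve(df_dy, df_dtheta)

# For comparison: differentiate the closed-form steady state directly.
dy_dtheta_direct = jax.jacobian(steady_state)(theta)
```

Only the Jacobians of \(f\) at the steady state are needed, which is why this approach can skip backpropagating through the solver.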

#### `diffrax.DirectAdjoint (AbstractAdjoint)`

A variant of `diffrax.RecursiveCheckpointAdjoint`. The differences are that
`DirectAdjoint`:

- Is less time+memory efficient at reverse-mode autodifferentiation (specifically,
these will increase every time `max_steps` increases past a power of 16);
- Cannot be reverse-mode autodifferentiated if `max_steps is None`;
- Supports forward-mode autodifferentiation.

So unless you need forward-mode autodifferentiation,
`diffrax.RecursiveCheckpointAdjoint` should be preferred.

#### `diffrax.adjoint_rms_seminorm(x: tuple[PyTree, PyTree, PyTree, PyTree]) -> Union[float, int]`

Defines an adjoint seminorm. This can frequently be used to increase the
efficiency of backpropagation via `diffrax.BacksolveAdjoint`, as follows:

```
adjoint_controller = diffrax.PIDController(norm=diffrax.adjoint_rms_seminorm)
adjoint = diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller)
diffrax.diffeqsolve(..., adjoint=adjoint)
```

Note that this means that any `stepsize_controller` specified for the forward pass
will not be automatically used for the backward pass (as `adjoint_controller`
overrides it), so you should specify any custom `rtol`, `atol` etc. for the
backward pass as well.

## Reference

```
@article{kidger2021hey,
author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry},
title={``{H}ey, that's not an {ODE}'': {F}aster {ODE} {A}djoints via
{S}eminorms},
year={2021},
journal={International Conference on Machine Learning}
}
```