# FAQ

### Compilation is taking a long time.

If you're using a Runge--Kutta method like `diffrax.Dopri5` etc., then try setting `scan_stages=True` when initialising the solver, for example `Dopri5(scan_stages=True)`. This will substantially reduce compile time at the expense of a slightly slower run time.

### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour.

Try switching to 64-bit precision, instead of the 32-bit precision that is the default in JAX. See the JAX documentation on double (64-bit) precision.
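A minimal sketch of enabling 64-bit precision, assuming the standard `jax_enable_x64` configuration flag. It is safest to run this at program startup, before any JAX arrays are created; the array `x` is only there to illustrate the effect.

```python
import jax

# Enable 64-bit precision. Run this before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Illustrative check: floating-point arrays now default to float64.
x = jnp.ones(3)
```

With the flag set, initial conditions such as `y0` are created as 64-bit arrays by default, so the solve should run in 64-bit precision as well.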

### How does this compare to `jax.experimental.ode.odeint`?

The equivalent solver in Diffrax is:

```
diffeqsolve(
    ...,
    solver=Dopri5(scan_stages=True),
    stepsize_controller=PIDController(rtol=1.4e-8, atol=1.4e-8),
    adjoint=BacksolveAdjoint(),
    max_steps=None,
)
```

In practice, `Tsit5` is usually a better solver than `Dopri5`. And the default adjoint method (`RecursiveCheckpointAdjoint`) is usually a better choice than `BacksolveAdjoint`.

### I'm getting a `CustomVJPException`.

This can happen if you use `diffrax.BacksolveAdjoint` incorrectly.

Gradients will be computed for:

- Everything in the `args` PyTree passed to `diffeqsolve(..., args=args)`;
- Everything in the `y0` PyTree passed to `diffeqsolve(..., y0=y0)`;
- Everything in the `terms` PyTree passed to `diffeqsolve(terms, ...)`.

Attempting to compute gradients with respect to anything else will result in this exception.

Example

Here is a minimal example of **wrong** code that will raise this exception.

```
from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

@eqx.filter_jit
@eqx.filter_value_and_grad
def run(model):
    def f(t, y, args):  # `model` captured via closure; is not part of the `terms` PyTree.
        return model(y)
    sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
                      adjoint=BacksolveAdjoint())
    return jnp.sum(sol.ys)

run(mlp)
```

Example

The corrected version of the previous example is as follows. In this case, the model is properly part of the PyTree structure of `terms`.

```
from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

class VectorField(eqx.Module):
    model: eqx.Module

    def __call__(self, t, y, args):
        return self.model(y)

@eqx.filter_jit
@eqx.filter_value_and_grad
def run(model):
    f = VectorField(model)
    sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]), adjoint=BacksolveAdjoint())
    return jnp.sum(sol.ys)

run(mlp)
```