# FAQ

### Compilation is taking a long time.

If you're using a Runge–Kutta method such as `diffrax.Dopri5`, then try setting `scan_stages=True` when initialising the solver, for example `Dopri5(scan_stages=True)`. This will substantially reduce compile time at the expense of a slightly slower run time.

### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour.

Try switching to 64-bit precision, instead of the 32-bit precision that is the default in JAX. See the JAX documentation on double (64-bit) precision.
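A minimal sketch of turning on 64-bit precision, assuming you can set the flag at the very start of your program, before any JAX arrays are created. JAX exposes this through its `jax_enable_x64` configuration option; setting the environment variable `JAX_ENABLE_X64=True` is an alternative.

```python
import jax

# Enable 64-bit ("double") precision globally. This should happen at
# program startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.array(1.0)
print(x.dtype)  # float64 once x64 mode is enabled
```

Everything created afterwards, including the arrays used inside `diffeqsolve`, then defaults to 64-bit floats.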

### How does this compare to `jax.experimental.ode.odeint`?

The equivalent solver in Diffrax is:

```python
diffeqsolve(
    ...,
    solver=Dopri5(scan_stages=True),
    stepsize_controller=PIDController(rtol=1.4e-8, atol=1.4e-8),
    adjoint=BacksolveAdjoint(),
    max_steps=None,
)
```


In practice, `Tsit5` is usually a better solver than `Dopri5`, and the default adjoint method (`RecursiveCheckpointAdjoint`) is usually a better choice than `BacksolveAdjoint`.

### I'm getting a `CustomVJPException`.

This can happen if you use `diffrax.BacksolveAdjoint` incorrectly. With `BacksolveAdjoint`, gradients can only be computed with respect to:

- Everything in the `args` PyTree passed to `diffeqsolve(..., args=args)`;
- Everything in the `y0` PyTree passed to `diffeqsolve(..., y0=y0)`;
- Everything in the `terms` PyTree passed to `diffeqsolve(terms, ...)`.

Attempting to compute gradients with respect to anything else will result in this exception.

Example

Here is a minimal example of wrong code that will raise this exception.

```python
from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

@eqx.filter_jit
@eqx.filter_grad
def run(model):
    def f(t, y, args):  # `model` is captured via closure; it is not part of the `terms` PyTree.
        return model(y)
    sol = diffeqsolve(ODETerm(f), Euler(), 0, 1, 0.1, jnp.array([1.0]),
                      adjoint=BacksolveAdjoint())
    return jnp.sum(sol.ys)

run(mlp)
```


Example

The corrected version of the previous example is as follows. In this case, the model is properly part of the PyTree structure of `terms`.

from diffrax import BacksolveAdjoint, diffeqsolve, Euler, ODETerm
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

class VectorField(eqx.Module):
model: eqx.Module

def __call__(self, t, y, args):
return self.model(y)

@eqx.filter_jit