Both Equinox and JAX provide a number of debugging tools.
Common sources of NaNs
A common source of NaNs on the forward pass is calling jnp.log or jnp.sqrt on a negative number, or dividing by zero. If you get a NaN whilst using these operations, check their inputs carefully.
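One way to check those inputs is with jax.debug.print, which prints runtime values from inside a jitted function (a plain print would only show tracers). The following is a minimal sketch; the function f and its inputs are purely illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Under JIT a plain print() shows abstract tracers; jax.debug.print
    # prints the actual runtime values when the compiled function executes.
    jax.debug.print("input to sqrt: {x}", x=x)
    return jnp.sqrt(x)


out = f(jnp.array([4.0, -1.0]))
# The printed input reveals the negative entry responsible for the NaN.
print(out)
```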
A common source of NaNs when backpropagating is using one of the above operations together with a jnp.where, for example y = jnp.where(x > 0, jnp.log(x), 0). In this case the NaN is created on the forward pass, but is then masked by the jnp.where. Unfortunately, when backpropagating, the order of the log and the where is flipped, and the NaN is no longer masked! (The where sends a zero cotangent to the masked-out branch, but this zero is then multiplied by the derivative of log, 1/x; at x = 0 this gives 0 * inf = NaN.)
The solution is to use the "double where" trick: bracket your computation by a where on both sides. For this example, safe_x = jnp.where(x > 0, x, 1); y = jnp.where(x > 0, jnp.log(safe_x), 0). This ensures that the NaN is never created in the first place.
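The minimal sketch below compares the gradient of the naive masked log against the double-where version at x = 0. The function names naive and double_where are illustrative only.

```python
import jax
import jax.numpy as jnp


def naive(x):
    # Forward pass is fine: log(0) = -inf is masked out by the where.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def double_where(x):
    # Inner where: replace unsafe inputs *before* calling log, so log never
    # sees x <= 0 and its derivative stays finite.
    safe_x = jnp.where(x > 0, x, 1.0)
    # Outer where: select the intended output, exactly as before.
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


x = jnp.array(0.0)
naive_grad = jax.grad(naive)(x)  # zero cotangent meets 1/x = inf, giving NaN
safe_grad = jax.grad(double_where)(x)  # stays finite
print(naive_grad, safe_grad)
```

Both functions agree on the forward pass; only the backward pass differs.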
Debugging runtime errors
If you are getting a runtime error from equinox.error_if, then you can control the on-error behaviour via the environment variable EQX_ON_ERROR. In particular, setting EQX_ON_ERROR=breakpoint will open a jax.debug.breakpoint where the error arises. See the runtime errors documentation for more information and for other values of this environment variable.
If run from jax.jit, then the equinox.error_if error will be a long error message starting with "INTERNAL: Generated function failed: CpuCallback error: RuntimeError: ...". You may prefer to use eqx.filter_jit, which will remove some of the extra boilerplate from the error message.
JAX itself provides the following tools:
- jax.debug.print function, for printing results under JIT.
- jax.debug.breakpoint function, for opening a debugger under JIT.
- JAX_DEBUG_NANS=1 environment variable, for halting the computation once a NaN is encountered. This works best for NaNs encountered on the forward pass and outside of loops. If your NaN occurs on the backward pass only, then try equinox.debug.backward_nan below. If the NaN occurs inside of a loop, then consider pairing this with JAX_DISABLE_JIT=1. (Many loops are implicitly jit'd.)
- JAX_DISABLE_JIT=1 environment variable, for running the computation without JIT. This will be much slower, so it isn't always practical.
- JAX_TRACEBACK_FILTERING=off environment variable, which means errors and debuggers will include JAX and Equinox internals (which are filtered out by default).
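The NaN check can also be switched on from Python through jax.config, which is convenient in a notebook where setting JAX_DEBUG_NANS=1 before startup is awkward. A minimal sketch (f is illustrative): with the flag on, the NaN from jnp.sqrt(-1.0) raises a FloatingPointError instead of silently propagating.

```python
import jax
import jax.numpy as jnp

# Programmatic equivalent of setting JAX_DEBUG_NANS=1.
jax.config.update("jax_debug_nans", True)


@jax.jit
def f(x):
    return jnp.sqrt(x) + 1.0


caught = None
try:
    # sqrt(-1) produces a NaN, which now raises instead of propagating.
    f(jnp.array(-1.0))
except FloatingPointError as e:
    caught = e
finally:
    # The flag is global, so switch it back off for later code.
    jax.config.update("jax_debug_nans", False)

print(type(caught).__name__, caught)
```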
equinox.debug.announce_transform(x, name = None, intermediates = False, announce: Callable[[str], Any] = <built-in function print>)
Identity function on an arbitrary PyTree. Announces each time a JAX transform is applied (grad, vmap, etc.).
This API is not stable. It should be used for one-off debugging purposes only.
x: a variable to intercept.
intermediates: whether to include intermediate transforms that haven't yet finished being transformed. E.g. with intermediates=True, jit(vmap(...)) will print out vmap, vmap:abstract and vmap:mlir, whereas with intermediates=False the unfinished intermediate transforms are omitted.
announce: the function to announce via. Defaults to just print.
The x argument is returned unchanged.
As a side-effect, the transforms applied to x will be printed out.
equinox.debug.backward_nan(x, name = None, terminate = True)
Debug NaNs that only occur on the backward pass.
x: a variable to intercept.
name: an optional name to appear in printed debug statements.
terminate: whether to halt the computation if a NaN cotangent is found. If True, then an error will be raised via equinox.error_if. (So you can also arrange for a breakpoint to trigger by setting EQX_ON_ERROR=breakpoint.)
The x argument is returned unchanged.
As a side-effect, both the primal and the cotangent for x will be printed out during the backward pass.
equinox.debug.breakpoint_if(pred: Array, **kwargs)
As jax.debug.breakpoint, but only triggers if pred is True.
pred: the predicate for whether to trigger the breakpoint.
**kwargs: any other keyword arguments to forward to jax.debug.breakpoint.
equinox.debug.store_dce(x: PyTree, name: Hashable = None)
Used to check whether a PyTree is DCE'd. (That is, whether this code has been removed by the compiler due to dead code elimination.)
store_dce must be used within a JIT'd function, and acts as the identity function. When the JIT'd function is called, whether or not each array got DCE'd is recorded. This can subsequently be inspected using inspect_dce.
Any non-arrays in x are ignored.
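```python
import equinox as eqx
import jax


@jax.jit
def f(x):
    # Only `a` is returned; `x + 1` is computed but never used.
    a, _ = eqx.debug.store_dce((x**2, x + 1))
    return a


f(1)
eqx.debug.inspect_dce()
# Found 1 call to `equinox.debug.store_dce`.
# Entry 0:
# (i32, <DCE'd>)
```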
x: Any PyTree. Its arrays are checked for being DCE'd.
name: Optional argument. Any hashable value used to distinguish this call site from another call site. If used, then it should be passed to inspect_dce to print only those entries with this name.
The x argument is returned unchanged.
equinox.debug.inspect_dce(name: Hashable = None)
Used in conjunction with equinox.debug.store_dce; see documentation there.
Must be called outside of any JIT'd function.
name: Optional argument. Whatever name was used with store_dce.
Nothing is returned. DCE information is printed to stdout.
equinox.debug.assert_max_traces(fn: Callable = sentinel, *, max_traces: Optional[int])
Asserts that the wrapped callable is not called more than max_traces times.
The typical use-case for this is to check that a JIT-compiled function is not compiled more than max_traces times. (I.e. this function can be used to guard against unexpected recompilation.) In this case it should be placed within the JIT wrapper.
```python
import equinox as eqx


@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(x, y):
    return x + y
```
fn: The callable to wrap.
max_traces: keyword-only argument. The maximum number of calls that are allowed. Can be None to allow arbitrarily many calls; in this case the number of calls can still be found via
A wrapped version of fn that tracks the number of times it is called. This will raise a RuntimeError if it is called more than max_traces times.
See also chex.assert_max_traces, which provides similar functionality.
One difference is that Chex's implementation is a bit stricter, as the following will raise:
```python
import chex
import jax


def f(x):
    pass


f2 = jax.jit(chex.assert_max_traces(f, 1))
f3 = jax.jit(chex.assert_max_traces(f, 1))
f2(1)
f3(1)  # will raise, despite the fact that f2 and f3 are different.
```
You may prefer the Chex version if you want the stricter raising behaviour of the above code.