Debugging tools
Both Equinox and JAX provide a number of debugging tools.
Common sources of NaNs
A common source of NaNs on the forward pass is calling jnp.log or jnp.sqrt on a negative number, or dividing by zero. If you get a NaN whilst using these operations, check their inputs carefully (e.g. with jax.debug.print).
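As a minimal sketch (the function f and the input values are illustrative only, not from Equinox), jax.debug.print can be placed just before jnp.log and jnp.sqrt inside a JIT'd function to reveal which inputs are invalid:

```python
import jax
import jax.numpy as jnp


# `f` is a hypothetical example function.
@jax.jit
def f(x):
    # Print the inputs before the NaN-producing operations; this works under JIT.
    jax.debug.print("inputs to log/sqrt: {x}", x=x)
    # log and sqrt of a negative number both produce NaN.
    return jnp.log(x) + jnp.sqrt(x)


out = f(jnp.array([4.0, -1.0]))
print(out)  # the first entry is finite, the second is NaN
```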
A common source of NaNs when backpropagating is using one of the above operations inside a jnp.where, for example y = jnp.where(x > 0, jnp.log(x), 0). In this case the NaN is created on the forward pass, but is then masked by the jnp.where. Unfortunately, when backpropagating, the order of the log and the where is flipped, and the NaN is no longer masked! The solution is to use the "double where" trick: bracket your computation by a where on both sides. For this example, safe_x = jnp.where(x > 0, x, 1); y = jnp.where(x > 0, jnp.log(safe_x), 0). This ensures that the NaN is never created in the first place.
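The following sketch compares the two patterns from the example above, evaluating the gradient at x = 0, where the single-where version produces a NaN. The function names naive and double_where are hypothetical.

```python
import jax
import jax.numpy as jnp


# `naive` and `double_where` are hypothetical names for the two patterns.
def naive(x):
    # NaN-free forward pass, but the gradient at x == 0 is NaN.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def double_where(x):
    # Replace invalid inputs before calling log, so the NaN is never created.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


naive_grad = jax.grad(naive)(0.0)
safe_grad = jax.grad(double_where)(0.0)
print(naive_grad)  # nan
print(safe_grad)  # 0.0
```

For x > 0 both functions agree, so the double where only changes behaviour on the masked-out branch.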
Debugging runtime errors
If you are getting a runtime error from equinox.error_if, then you can control the on-error behaviour via the environment variable EQX_ON_ERROR. In particular, setting EQX_ON_ERROR=breakpoint will open a jax.debug.breakpoint where the error arises. See the documentation on runtime errors for more information and for the other values of this environment variable.
When running under jax.jit, the equinox.error_if error will be a long error message starting with INTERNAL: Generated function failed: CpuCallback error: RuntimeError: .... You may prefer to use eqx.filter_jit, which removes some of the extra boilerplate from the error message.
JAX tools
JAX itself provides the following tools:
- the jax.debug.print function, for printing results under JIT.
- the jax.debug.breakpoint function, for opening a debugger under JIT.
- the JAX_DEBUG_NANS=1 environment variable, for halting the computation once a NaN is encountered. This works best for NaNs encountered on the forward pass and outside of loops. If your NaN occurs only on the backward pass, then try equinox.debug.backward_nan below. If the NaN occurs inside of a loop, then consider pairing this with JAX_DISABLE_JIT=1. (Many loops are implicitly jit'd.)
- the JAX_DISABLE_JIT=1 environment variable, for running the computation without JIT. This will be much slower, so it isn't always practical.
- the JAX_TRACEBACK_FILTERING=off environment variable, which makes errors and debuggers include JAX and Equinox internals (these are filtered out by default).
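As a sketch of several of these tools together: JAX_DEBUG_NANS=1 corresponds to the jax_debug_nans configuration option, which can also be toggled from Python via jax.config.update, and jax.disable_jit() is a context-manager counterpart of JAX_DISABLE_JIT=1. The function f is illustrative.

```python
import jax
import jax.numpy as jnp


# `f` is a hypothetical example function.
@jax.jit
def f(x):
    # Printed from inside the compiled function.
    jax.debug.print("x = {x}", x=x)
    return jnp.log(x)


# Programmatic equivalent of setting JAX_DEBUG_NANS=1.
jax.config.update("jax_debug_nans", True)
try:
    f(-1.0)
    caught = None
except FloatingPointError as e:
    # JAX re-runs the function without JIT to locate the primitive that produced the NaN.
    caught = type(e).__name__
    print("caught:", e)
finally:
    jax.config.update("jax_debug_nans", False)

# Programmatic equivalent of JAX_DISABLE_JIT=1, scoped to this block.
with jax.disable_jit():
    y = f(2.0)
```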
Equinox tools
equinox.debug.announce_transform(x, name=None, intermediates=False, announce: Callable[[str], Any] = <built-in function print>)
Identity function on an arbitrary PyTree. Announces each time a JAX transform is applied (grad, vmap, etc.).
Warning
This API is not stable. It should be used for one-off debugging purposes only.
Arguments:
- x: a variable to intercept.
- intermediates: whether to include intermediate transforms, that haven't yet finished being transformed. E.g. if intermediates=False, then jit(vmap(...)) will print out vmap:abstract and vmap:mlir, whilst if intermediates=True, then jit(vmap(...)) will print out vmap, vmap:abstract and vmap:mlir.
- announce: the function to announce via. Defaults to just print.
Returns:
The x argument is returned unchanged.
As a side-effect, the transforms applied to x will be printed out.
equinox.debug.backward_nan(x, name=None, terminate=True)
Debug NaNs that only occur on the backward pass.
Arguments:
- x: a variable to intercept.
- name: an optional name to appear in printed debug statements.
- terminate: whether to halt the computation if a NaN cotangent is found. If True then an error will be raised via equinox.error_if. (So you can also arrange for a breakpoint to trigger by setting EQX_ON_ERROR appropriately.)
Returns:
The x argument is returned unchanged.
As a side-effect, both the primal and the cotangent for x will be printed out during the backward pass.
equinox.debug.breakpoint_if(pred: Bool[Array, '...'], **kwargs)
As jax.debug.breakpoint, but only triggers if pred is True.
Arguments:
- pred: the predicate for whether to trigger the breakpoint.
- **kwargs: any other keyword arguments to forward to jax.debug.breakpoint.
Returns:
Nothing.
equinox.debug.store_dce(x: PyTree, name: Hashable = None)
Used to check whether a PyTree is DCE'd. (That is, whether this code has been removed by the compiler due to dead code elimination.)
store_dce must be used within a JIT'd function, and acts as the identity
function. When the JIT'd function is called, then whether each array got DCE'd or
not is recorded. This can subsequently be inspected using inspect_dce.
Any non-arrays in x are ignored.
Example
import equinox as eqx
import jax

@jax.jit
def f(x):
    a, _ = eqx.debug.store_dce((x**2, x + 1))
    return a

f(1)
eqx.debug.inspect_dce()
# Found 1 call to `equinox.debug.store_dce`.
# Entry 0:
# (i32[], <DCE'd>)
Arguments:
- x: Any PyTree. Its arrays are checked for being DCE'd.
- name: Optional argument. Any hashable value used to distinguish this call site from another call site. If used, then it should be passed to inspect_dce to print only those entries with this name.
Returns:
x is returned unchanged.
equinox.debug.inspect_dce(name: Hashable = None)
Used in conjunction with equinox.debug.store_dce; see documentation there.
Must be called outside of any JIT'd function.
Arguments:
- name: Optional argument. Whatever name was used with store_dce.
Returns:
Nothing. DCE information is printed to stdout.
equinox.debug.assert_max_traces(fn: Callable = sentinel, *, max_traces: int | None)
Asserts that the wrapped callable is not called more than max_traces times.
The typical use-case for this is to check that a JIT-compiled function is not
compiled more than max_traces times. (I.e. this function can be used to guard
against bugs.) In this case it should be placed within the JIT wrapper.
Example
import equinox as eqx

@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(x, y):
    return x + y
Arguments:
- fn: The callable to wrap.
- max_traces: keyword only argument. The maximum number of calls that are allowed. Can be None to allow arbitrarily many calls; in this case the number of calls can still be found via equinox.debug.get_num_traces.
Returns:
A wrapped version of fn that tracks the number of times it is called. This will
raise a RuntimeError if it is called more than max_traces times.
See also
chex.assert_max_traces, which provides similar functionality.
The differences are that (a) Chex's implementation is a bit stricter, as the following will raise:
import chex
import jax

def f(x):
    pass

f2 = jax.jit(chex.assert_max_traces(f, 1))
f3 = jax.jit(chex.assert_max_traces(f, 1))
f2(1)
f3(1)  # will raise, despite the fact that f2 and f3 are different.
You may prefer the Chex version if you want the stricter raising behaviour of the above code.
equinox.debug.get_num_traces(fn) -> int
Given a function wrapped in equinox.debug.assert_max_traces, return the
number of times it has been traced so far.