Debugging tools
Both Equinox and JAX provide a number of debugging tools.
Common sources of NaNs
A common source of NaNs on the forward pass is calling jnp.log or jnp.sqrt on a negative number, or dividing by zero (0 / 0 is NaN directly, while x / 0 for nonzero x is inf, which can later become NaN, e.g. via inf - inf or 0 * inf). If you get a NaN whilst using these operations, check their inputs carefully (e.g. with jax.debug.print).
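As a minimal sketch (the function name log_of_inputs and the inputs are illustrative), jax.debug.print can report the values reaching jnp.log from inside a JIT'd function, which makes an offending negative entry easy to spot:

```python
import jax
import jax.numpy as jnp


@jax.jit
def log_of_inputs(x):
    # jax.debug.print runs at execution time, so it shows concrete values
    # under JIT (a plain print would only show the abstract tracer).
    jax.debug.print("input to jnp.log: {x}", x=x)
    return jnp.log(x)


out = log_of_inputs(jnp.array([1.0, -1.0]))
# The printed input reveals the negative entry, whose log is NaN.
```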
A common source of NaNs when backpropagating is using one of the above operations together with jnp.where, for example y = jnp.where(x > 0, jnp.log(x), 0). In this case the NaN is created on the forward pass, but is then masked by the jnp.where. Unfortunately, when backpropagating, the order of the log and the where is flipped, and the NaN is no longer masked! The solution is the "double where" trick: bracket your computation with a where on both sides. For this example, safe_x = jnp.where(x > 0, x, 1); y = jnp.where(x > 0, jnp.log(safe_x), 0). This ensures that the NaN is never created in the first place.
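The difference between the two patterns can be checked directly with jax.grad. This small sketch uses the example above at x = 0, where jnp.log(0) is -inf on the forward pass; the function names single_where and double_where are illustrative:

```python
import jax
import jax.numpy as jnp


def single_where(x):
    # Finite on the forward pass, but the cotangent flowing into jnp.log(x)
    # is divided by x, giving 0 / 0 = NaN at x = 0.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def double_where(x):
    # Replace unsafe inputs *before* calling jnp.log, so no inf/NaN is created.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


grad_single = jax.grad(single_where)(0.0)
grad_double = jax.grad(double_where)(0.0)
```

At x = 0 the single-where version has a NaN gradient, whilst the double-where version gives 0; for x > 0 both give the usual 1 / x.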
Debugging runtime errors
If you are getting a runtime error from equinox.error_if, then you can control the on-error behaviour via the environment variable EQX_ON_ERROR. In particular, setting EQX_ON_ERROR=breakpoint will open a jax.debug.breakpoint where the error arises. See the runtime errors documentation for more information and for the other values of this environment variable.
If run under jax.jit, then the equinox.error_if error will be a long error message starting INTERNAL: Generated function failed: CpuCallback error: RuntimeError: .... You may prefer to use eqx.filter_jit, which removes some of the extra boilerplate from the error message.
JAX tools
JAX itself provides the following tools:
- the jax.debug.print function, for printing results under JIT.
- the jax.debug.breakpoint function, for opening a debugger under JIT.
- the JAX_DEBUG_NANS=1 environment variable, for halting the computation once a NaN is encountered. This works best for NaNs encountered on the forward pass and outside of loops. If your NaN occurs only on the backward pass, then try equinox.debug.backward_nan below. If the NaN occurs inside of a loop, then consider pairing this with JAX_DISABLE_JIT=1. (Many loops are implicitly JIT'd.)
- the JAX_DISABLE_JIT=1 environment variable, for running the computation without JIT. This will be much slower, so it isn't always practical.
- the JAX_TRACEBACK_FILTERING=off environment variable, which makes errors and debuggers include JAX and Equinox internals (which are filtered out by default).
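Several of these environment variables have corresponding jax.config options, which can be handy in a notebook. A minimal sketch using the jax_debug_nans option, the in-process counterpart of JAX_DEBUG_NANS=1; the function f is illustrative:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.log(x)


# In-process equivalent of running with JAX_DEBUG_NANS=1.
jax.config.update("jax_debug_nans", True)
nan_error = None
try:
    f(jnp.array(-1.0))
except FloatingPointError as e:
    # JAX re-runs the function op-by-op to report which operation made the NaN.
    nan_error = e
finally:
    jax.config.update("jax_debug_nans", False)

# With the option switched off again, the NaN is returned silently.
silent = f(jnp.array(-1.0))
```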
Equinox tools
  equinox.debug.announce_transform(x, name=None, intermediates=False, announce: Callable[[str], Any] = <built-in function print>)
  
    Identity function on an arbitrary PyTree. Announces each time a JAX transform is applied (grad, vmap, etc.).
Warning
This API is not stable. It should be used for one-off debugging purposes only.
Arguments:
- x: a variable to intercept.
- intermediates: whether to include intermediate transforms that haven't yet finished being transformed. E.g. if intermediates=False, then jit(vmap(...)) will print out vmap:abstract and vmap:mlir, whilst if intermediates=True, then jit(vmap(...)) will print out vmap, vmap:abstract and vmap:mlir.
- announce: the function to announce via. Defaults to just print.
Returns:
The x argument is returned unchanged.
As a side-effect, the transforms applied to x will be printed out.
  equinox.debug.backward_nan(x, name=None, terminate=True)
  
    Debug NaNs that only occur on the backward pass.
Arguments:
- x: a variable to intercept.
- name: an optional name to appear in printed debug statements.
- terminate: whether to halt the computation if a NaN cotangent is found. If True then an error will be raised via equinox.error_if. (So you can also arrange for a breakpoint to trigger by setting EQX_ON_ERROR appropriately.)
Returns:
The x argument is returned unchanged.
As a side-effect, both the primal and the cotangent for x will be printed out during the backward pass.
  equinox.debug.breakpoint_if(pred: Bool[Array, '...'], **kwargs)
  
    As jax.debug.breakpoint, but only triggers if pred is True.
Arguments:
- pred: the predicate for whether to trigger the breakpoint.
- **kwargs: any other keyword arguments to forward to jax.debug.breakpoint.
Returns:
Nothing.
  equinox.debug.store_dce(x: PyTree, name: Hashable = None)
  
    Used to check whether a PyTree is DCE'd. (That is, whether this code has been removed in the compiler, due to dead code elimination.)
store_dce must be used within a JIT'd function, and acts as the identity function. When the JIT'd function is called, whether each array got DCE'd or not is recorded. This can subsequently be inspected using inspect_dce.
Any non-arrays in x are ignored.
Example
```python
import equinox as eqx
import jax

@jax.jit
def f(x):
    a, _ = eqx.debug.store_dce((x**2, x + 1))
    return a

f(1)
eqx.debug.inspect_dce()
# Found 1 call to `equinox.debug.store_dce`.
# Entry 0:
# (i32[], <DCE'd>)
```
Arguments:
- x: Any PyTree. Its arrays are checked for being DCE'd.
- name: Optional argument. Any hashable value used to distinguish this call site from another call site. If used, then it should be passed to inspect_dce to print only those entries with this name.
Returns:
x is returned unchanged.
  equinox.debug.inspect_dce(name: Hashable = None)
  
    Used in conjunction with equinox.debug.store_dce; see documentation there.
Must be called outside of any JIT'd function.
Arguments:
- name: Optional argument. Whatever name was used with store_dce.
Returns:
Nothing. DCE information is printed to stdout.
  equinox.debug.assert_max_traces(fn: Callable = sentinel, *, max_traces: int | None)
  
    Asserts that the wrapped callable is not called more than max_traces times.
The typical use-case for this is to check that a JIT-compiled function is not compiled more than max_traces times. (I.e. this function can be used to guard against bugs.) In this case it should be placed within the JIT wrapper.
Example
```python
import equinox as eqx

@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(x, y):
    return x + y
```
Arguments:
- fn: The callable to wrap.
- max_traces: keyword-only argument. The maximum number of calls that are allowed. Can be None to allow arbitrarily many calls; in this case the number of calls can still be found via equinox.debug.get_num_traces.
Returns:
A wrapped version of fn that tracks the number of times it is called. This will raise a RuntimeError if it is called more than max_traces many times.
See also: chex.assert_max_traces, which provides similar functionality.
One difference is that Chex's implementation is a bit stricter, as the following will raise:
```python
import chex
import jax

def f(x):
    pass

f2 = jax.jit(chex.assert_max_traces(f, 1))
f3 = jax.jit(chex.assert_max_traces(f, 1))
f2(1)
f3(1)  # will raise, despite the fact that f2 and f3 are different.
```
You may prefer the Chex version if you want the stricter raising behaviour shown above.
  equinox.debug.get_num_traces(fn) -> int
  
    Given a function wrapped in equinox.debug.assert_max_traces, return the number of times it has been traced so far.