
Debugging tools

Both Equinox and JAX provide a number of debugging tools.

Common sources of NaNs

A common source of NaNs on the forward pass is calling jnp.log or jnp.sqrt on a negative number, or dividing by zero. If you get a NaN whilst using these operations, check their inputs carefully (e.g. with jax.debug.print).

A common source of NaNs when backpropagating is using one of the above operations inside a jnp.where, for example y = jnp.where(x > 0, jnp.log(x), 0). In this case the NaN is created on the forward pass, but is then masked by the jnp.where. Unfortunately, when backpropagating, the order of the log and the where is flipped, and the NaN is no longer masked: the where passes a zero cotangent back to the log, and combining that zero with the log's infinite derivative at x = 0 gives NaN. The solution is to use the "double where" trick: bracket your computation by a where on both sides. For this example, safe_x = jnp.where(x > 0, x, 1); y = jnp.where(x > 0, jnp.log(safe_x), 0). This ensures that the NaN is never created in the first place.
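A minimal sketch of the problem and the fix, assuming only JAX is installed. The names naive and safe are illustrative, and x = 0.0 is chosen because that is where the derivative of jnp.log blows up.

```python
import jax
import jax.numpy as jnp


def naive(x):
    # Single where: jnp.log(0.0) = -inf is masked on the forward pass only.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe(x):
    # Double where: swap bad inputs for a harmless value (1.0) before jnp.log sees them.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


x = 0.0
print(naive(x), safe(x))   # forward pass: both are masked to 0.0
print(jax.grad(naive)(x))  # nan
print(jax.grad(safe)(x))   # 0.0
```

For x > 0 the two functions agree on both the forward and backward pass; the inner where only changes what jnp.log is evaluated on in the masked-out region.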

Debugging runtime errors

If you are getting a runtime error from equinox.error_if, then you can control the on-error behaviour via the environment variable EQX_ON_ERROR. In particular, setting EQX_ON_ERROR=breakpoint will open a jax.debug.breakpoint where the error arises. See the documentation on runtime errors for more information, including the other values this environment variable accepts.

If run from within jax.jit, then the equinox.error_if error will be a long error message starting with INTERNAL: Generated function failed: CpuCallback error: RuntimeError: .... You may prefer to use eqx.filter_jit, which removes some of the extra boilerplate from the error message.

JAX tools

JAX itself provides the following tools:

  • the jax.debug.print function, for printing results under JIT.
  • the jax.debug.breakpoint function, for opening a debugger under JIT.
  • the JAX_DEBUG_NANS=1 environment variable, for halting the computation once a NaN is encountered. This works best for NaNs encountered on the forward pass and outside of loops. If your NaN occurs on the backward pass only, then try equinox.debug.backward_nan below. If the NaN occurs inside of a loop, then consider pairing this with JAX_DISABLE_JIT=1. (Many loops are implicitly jit'd.)
  • the JAX_DISABLE_JIT=1 environment variable, for running the computation without JIT. This is much slower, so it isn't always practical.
  • the JAX_TRACEBACK_FILTERING=off environment variable, so that errors and debuggers include JAX and Equinox internals (which are filtered out by default).
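A minimal sketch combining two of these tools, assuming a recent JAX version: jax.debug.print to inspect values under JIT, and jax.config.update("jax_debug_nans", True), the in-code counterpart of JAX_DEBUG_NANS=1. The function f is purely illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    y = jnp.sqrt(x)
    # Under JIT a plain print() would only show tracers; jax.debug.print shows runtime values.
    jax.debug.print("x = {x}, sqrt(x) = {y}", x=x, y=y)
    return y


out = f(jnp.array([4.0, -1.0]))  # sqrt(-1.0) silently produces nan

# In-code equivalent of setting JAX_DEBUG_NANS=1: raise as soon as a NaN appears.
jax.config.update("jax_debug_nans", True)
try:
    f(jnp.array([4.0, -1.0]))
except FloatingPointError as e:
    print("NaN detected:", e)
finally:
    jax.config.update("jax_debug_nans", False)
```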

Equinox tools

equinox.debug.announce_transform(x, name = None, intermediates = False, announce: Callable[[str], Any] = <built-in function print>)

Identity function on an arbitrary PyTree. Announces each time a JAX transform is applied (grad, vmap, etc.).

Warning

This API is not stable. It should be used for one-off debugging purposes only.

Arguments:

  • x: a variable to intercept.
  • name: an optional name to appear in the announced messages.
  • intermediates: whether to include intermediate transforms that haven't yet finished being transformed. E.g. if intermediates=False, then jit(vmap(...)) will print out
    vmap:abstract
    vmap:mlir
    
    whilst if intermediates=True, then jit(vmap(...)) will print out
    vmap
    vmap:abstract
    vmap:mlir
    
  • announce: the function to announce via. Defaults to just print.

Returns:

The x argument is returned unchanged.

As a side-effect, the transforms applied to x will be printed out.


equinox.debug.backward_nan(x, name = None, terminate = True)

Debug NaNs that only occur on the backward pass.

Arguments:

  • x: a variable to intercept.
  • name: an optional name to appear in printed debug statements.
  • terminate: whether to halt the computation if a NaN cotangent is found. If True then an error will be raised via equinox.error_if. (So you can also arrange for a breakpoint to trigger by setting EQX_ON_ERROR appropriately.)

Returns:

The x argument is returned unchanged.

As a side-effect, both the primal and the cotangent for x will be printed out during the backward pass.


equinox.debug.breakpoint_if(pred: Array, **kwargs)

As jax.debug.breakpoint, but only triggers if pred is True.

Arguments:

  • pred: the predicate for whether to trigger the breakpoint.
  • **kwargs: any other keyword arguments to forward to jax.debug.breakpoint.

Returns:

Nothing.


equinox.debug.store_dce(x: PyTree, name: Hashable = None)

Used to check whether a PyTree is DCE'd. (That is, whether the compiler has removed this code via dead code elimination.)

store_dce must be used within a JIT'd function, and acts as the identity function. When the JIT'd function is called, then whether each array got DCE'd or not is recorded. This can subsequently be inspected using inspect_dce.

Any non-arrays in x are ignored.

Example

import equinox as eqx
import jax

@jax.jit
def f(x):
    a, _ = eqx.debug.store_dce((x**2, x + 1))
    return a

f(1)
eqx.debug.inspect_dce()
# Found 1 call to `equinox.debug.store_dce`.
# Entry 0:
# (i32[], <DCE'd>)

Arguments:

  • x: Any PyTree. Its arrays are checked for being DCE'd.
  • name: Optional argument. Any hashable value used to distinguish this call site from another call site. If used, then it should be passed to inspect_dce to print only those entries with this name.

Returns:

x is returned unchanged.

equinox.debug.inspect_dce(name: Hashable = None)

Used in conjunction with equinox.debug.store_dce; see documentation there.

Must be called outside of any JIT'd function.

Arguments:

  • name: Optional argument. Whatever name was used with store_dce.

Returns:

Nothing. DCE information is printed to stdout.


equinox.debug.assert_max_traces(fn: Callable = sentinel, *, max_traces: Optional[int])

Asserts that the wrapped callable is not called more than max_traces times.

The typical use-case for this is to check that a JIT-compiled function is not compiled more than max_traces times. (I.e. this function can guard against bugs that cause unexpected recompilation.) In this case it should be placed within the JIT wrapper.

Example

import equinox as eqx

@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(x, y):
    return x + y

Arguments:

  • fn: The callable to wrap.
  • max_traces: keyword-only argument. The maximum number of calls that are allowed. Can be None to allow arbitrarily many calls; in this case the number of calls can still be found via equinox.debug.get_num_traces.

Returns:

A wrapped version of fn that tracks the number of times it is called. This will raise a RuntimeError if it is called more than max_traces many times.

Info

See also chex.assert_max_traces which provides similar functionality.

The differences are that (a) Chex's implementation is a bit stricter, as the following will raise:

import chex
import jax

def f(x):
    pass

f2 = jax.jit(chex.assert_max_traces(f, 1))
f3 = jax.jit(chex.assert_max_traces(f, 1))

f2(1)
f3(1)  # will raise, despite the fact that f2 and f3 are different.

and (b) Equinox's implementation supports non-function callables.

You may prefer the Chex version if you want the stricter raising behaviour of the above code.

equinox.debug.get_num_traces(fn) -> int

Given a function wrapped in equinox.debug.assert_max_traces, return the number of times it has been traced so far.