Debugging tools
Both Equinox and JAX provide a number of debugging tools.
Common sources of NaNs
A common source of NaNs on the forward pass is calling `jnp.log` or `jnp.sqrt` on a negative number, or dividing by zero. If you get a NaN whilst using these operations, check their inputs carefully (e.g. with `jax.debug.print`).
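As a minimal sketch (the function `f` and its inputs are illustrative), `jax.debug.print` can report the inputs to `jnp.log` from inside a JIT-compiled function, where an ordinary `print` would only show tracers:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Runs when the compiled function executes, so concrete values are printed.
    jax.debug.print("input to jnp.log: {x}", x=x)
    return jnp.log(x)


y = f(jnp.array([1.0, -1.0, 0.0]))
# log(-1.0) is NaN and log(0.0) is -inf.
```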
A common source of NaNs when backpropagating is using one of the above operations with a `jnp.where`, for example `y = jnp.where(x > 0, jnp.log(x), 0)`. In this case the NaN is created on the forward pass, but is then masked by the `jnp.where`. Unfortunately, when backpropagating, the order of the `log` and the `where` is flipped, and the NaN is no longer masked! The solution is to use the "double where" trick: bracket your computation by a `where` on both sides. For this example, `safe_x = jnp.where(x > 0, x, 1); y = jnp.where(x > 0, jnp.log(safe_x), 0)`. This ensures that the NaN is never created in the first place.
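The effect of the double-where trick can be checked directly with `jax.grad`. In this sketch `naive` and `double_where` are illustrative names. At `x = 0` the cotangent reaching `jnp.log` is zero, but the derivative of `log` divides by `x`, so the naive version produces 0/0 = NaN; the double-where version feeds `jnp.log` the harmless value 1.0 instead.

```python
import jax
import jax.numpy as jnp


def naive(x):
    # NaN-prone: on the backward pass the gradient of jnp.log at x == 0 is 0 / 0.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def double_where(x):
    # First where: replace unsafe inputs with a harmless value (1.0).
    safe_x = jnp.where(x > 0, x, 1.0)
    # Second where: keep the real result only where it is valid.
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


xs = jnp.array([0.0, 1.0, 2.0])
grad_naive = jax.vmap(jax.grad(naive))(xs)
grad_safe = jax.vmap(jax.grad(double_where))(xs)
print(grad_naive)  # the first entry is nan
print(grad_safe)   # values 0.0, 1.0, 0.5 with no NaNs
```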
Debugging runtime errors
If you are getting a runtime error from `equinox.error_if`, then you can control the on-error behaviour via the environment variable `EQX_ON_ERROR`. In particular, setting `EQX_ON_ERROR=breakpoint` will open a `jax.debug.breakpoint` where the error arises. See the documentation on runtime errors for more information and for other values of this environment variable.
If run from `jax.jit`, then the `equinox.error_if` error will be a long error message starting `INTERNAL: Generated function failed: CpuCallback error: RuntimeError: ...`. You may prefer to use `eqx.filter_jit`, which will remove some of the extra boilerplate from the error message.
JAX tools
JAX itself provides the following tools:
- the `jax.debug.print` function, for printing results under JIT.
- the `jax.debug.breakpoint` function, for opening a debugger under JIT.
- the `JAX_DEBUG_NANS=1` environment variable, for halting the computation once a NaN is encountered. This works best for NaNs encountered on the forward pass and outside of loops. If your NaN occurs on the backward pass only, then try `equinox.debug.backward_nan` below. If the NaN occurs inside of a loop, then consider pairing this with `JAX_DISABLE_JIT=1`. (Many loops are implicitly jit'd.)
- the `JAX_DISABLE_JIT=1` environment variable, for running the computation without JIT. This will be much slower, so this isn't always practical.
- the `JAX_TRACEBACK_FILTERING=off` environment variable, which means errors and debuggers will include JAX and Equinox internals. (By default these are filtered out.)
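These flags can also be toggled from inside Python via `jax.config.update`. As a minimal sketch (the function `g` is illustrative), the following enables the in-process equivalent of `JAX_DEBUG_NANS=1`, catches the resulting `FloatingPointError`, and switches the check back off afterwards:

```python
import jax
import jax.numpy as jnp

# In-process equivalent of setting JAX_DEBUG_NANS=1 in the environment.
jax.config.update("jax_debug_nans", True)


@jax.jit
def g(x):
    return jnp.log(x) + 1.0


try:
    g(jnp.array(-1.0))
    caught_nan = False
except FloatingPointError as e:
    # JAX reports that a NaN was produced, and where.
    caught_nan = True
    print(e)
finally:
    # Switch the check back off; it adds overhead to every computation.
    jax.config.update("jax_debug_nans", False)
```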
Equinox tools
equinox.debug.announce_transform(x, name = None, intermediates = False, announce: Callable[[str], Any] = <built-in function print>)
Identity function on an arbitrary PyTree. Announces each time a JAX transform is applied (grad, vmap, etc.).
Warning
This API is not stable. It should be used for one-off debugging purposes only.
Arguments:
- `x`: a variable to intercept.
- `intermediates`: whether to include intermediate transforms that haven't yet finished being transformed. E.g. if `intermediates=False`, then `jit(vmap(...))` will print out `vmap:abstract` followed by `vmap:mlir`, whilst if `intermediates=True`, then `jit(vmap(...))` will print out `vmap`, `vmap:abstract` and `vmap:mlir`.
- `announce`: the function to announce via. Defaults to just `print`.
Returns:
The `x` argument is returned unchanged.
As a side-effect, the transforms applied to `x` will be printed out.
equinox.debug.backward_nan(x, name = None, terminate = True)
Debug NaNs that only occur on the backward pass.
Arguments:
- `x`: a variable to intercept.
- `name`: an optional name to appear in printed debug statements.
- `terminate`: whether to halt the computation if a NaN cotangent is found. If `True` then an error will be raised via `equinox.error_if`. (So you can also arrange for a breakpoint to trigger by setting `EQX_ON_ERROR` appropriately.)
Returns:
The `x` argument is returned unchanged.
As a side-effect, both the primal and the cotangent for `x` will be printed out during the backward pass.
equinox.debug.breakpoint_if(pred: Array, **kwargs)
As `jax.debug.breakpoint`, but only triggers if `pred` is True.
Arguments:
- `pred`: the predicate for whether to trigger the breakpoint.
- `**kwargs`: any other keyword arguments to forward to `jax.debug.breakpoint`.
Returns:
Nothing.
equinox.debug.store_dce(x: PyTree, name: Hashable = None)
Used to check whether a PyTree is DCE'd. (That is, whether this code has been removed in the compiler, due to dead code elimination.)
`store_dce` must be used within a JIT'd function, and acts as the identity function. When the JIT'd function is called, then whether each array got DCE'd or not is recorded. This can subsequently be inspected using `inspect_dce`.
Any non-arrays in `x` are ignored.
Example
```python
import equinox as eqx
import jax

@jax.jit
def f(x):
    a, _ = eqx.debug.store_dce((x**2, x + 1))
    return a

f(1)
eqx.debug.inspect_dce()
# Found 1 call to `equinox.debug.store_dce`.
# Entry 0:
# (i32[], <DCE'd>)
```
Arguments:
- `x`: Any PyTree. Its arrays are checked for being DCE'd.
- `name`: Optional argument. Any hashable value used to distinguish this call site from another call site. If used, then it should be passed to `inspect_dce` to print only those entries with this name.
Returns:
`x` is returned unchanged.
equinox.debug.inspect_dce(name: Hashable = None)
Used in conjunction with `equinox.debug.store_dce`; see documentation there.
Must be called outside of any JIT'd function.
Arguments:
- `name`: Optional argument. Whatever name was used with `store_dce`.
Returns:
Nothing. DCE information is printed to stdout.
equinox.debug.assert_max_traces(fn: Callable = sentinel, *, max_traces: Optional[int])
Asserts that the wrapped callable is not called more than `max_traces` times.
The typical use-case for this is to check that a JIT-compiled function is not compiled more than `max_traces` times. (I.e. this function can be used to guard against bugs.) In this case it should be placed within the JIT wrapper.
Example
```python
import equinox as eqx

@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(x, y):
    return x + y
```
Arguments:
- `fn`: The callable to wrap.
- `max_traces`: keyword only argument. The maximum number of calls that are allowed. Can be `None` to allow arbitrarily many calls; in this case the number of calls can still be found via `equinox.debug.get_num_traces`.
Returns:
A wrapped version of `fn` that tracks the number of times it is called. This will raise a `RuntimeError` if it is called more than `max_traces` many times.
Info
See also `chex.assert_max_traces`, which provides similar functionality.
One difference is that Chex's implementation is a bit stricter, as the following will raise:
```python
import chex
import jax

def f(x):
    pass

f2 = jax.jit(chex.assert_max_traces(f, 1))
f3 = jax.jit(chex.assert_max_traces(f, 1))
f2(1)
f3(1)  # will raise, despite the fact that f2 and f3 are different.
```
You may prefer the Chex version if you want the stricter raising behaviour shown above.
equinox.debug.get_num_traces(fn) -> int
Given a function wrapped in `equinox.debug.assert_max_traces`, return the number of times it has been traced so far.