Runtime errors
Equinox offers support for raising runtime errors.
How does this compare to checkify in core JAX?
- checkify is not compatible with operations like jax.jit or jax.lax.scan. (It must be "functionalised" first, using jax.experimental.checkify.checkify, and you then need to pipe a handle to the error through your code.) In contrast, Equinox's errors will "just work" without any extra effort.
- checkify stores all errors encountered whilst running your program, and then raises them at the end of the JIT'd region. For example, this means that (the JAX equivalent of) the following pseudocode will still end up in an infinite loop: with h < 0 the loop never terminates, so the end of the computation never arrives.

        if h < 0:
            error()
        while t < t_max:
            t += h

Meanwhile, Equinox's errors do not wait until the end, so the above computation will have the correct behaviour: the error is raised as soon as the check on h fails.
Warning
JAX's support for raising runtime errors is technically only experimental. In practice, this nonetheless seems to be stable enough that these are part of the public API for Equinox.
equinox.error_if(x: PyTree, pred: ArrayLike, msg: str) -> PyTree
Throws an error based on runtime values. Works even under JIT.
Arguments:
- x: will be returned unchanged. This is used to determine where the error check happens in the overall computation: it will happen after x is computed and before the return value is used. x can be any PyTree, and it must contain at least one array.
- pred: a boolean for whether to raise an error. Can be an array of bools; an error will be raised if any of them are True. If vmap'd, then an error will be raised if any batch element has True.
- msg: the string to display as an error message.
In addition, the EQX_ON_ERROR environment variable determines how runtime errors are handled. Possible values are:
- EQX_ON_ERROR=raise (the default) will raise a runtime error.
- EQX_ON_ERROR=nan will return NaN instead of x, and then continue the computation.
- EQX_ON_ERROR=breakpoint will open a debugger.
    - Note that this option may prevent certain compiler optimisations, so permanently fixing this value is not recommended.
    - If you are using pytest, you will also need to pass it the -s flag (which disables output capturing).
    - By default this only allows you to see a single frame in the debugger. This is to work around JAX bug #16732. (Bugs whilst debugging bugs, eek!) In practice you may like to set the EQX_ON_ERROR_BREAKPOINT_FRAMES environment variable to a small integer, which specifies how many frames upwards the debugger should capture. The JAX bug is triggered when taking too many frames.
After changing an environment variable, the Python process must be restarted.
Returns:
The original argument x, unchanged. If this return value is unused then the error check will not be performed. (It will be removed as part of dead code elimination.)
Example
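import jax
from equinox import error_if

@jax.jit
def f(x):
    x = error_if(x, x < 0, "x must be >= 0")
    # ...use x in your computation...
    return x

# With the default EQX_ON_ERROR=raise, this raises a runtime error.
f(jax.numpy.array(-1))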
equinox.branched_error_if(x: PyTree, pred: ArrayLike, index: ArrayLike, msgs: Sequence[str]) -> PyTree
As equinox.error_if, but will raise one of several msgs depending on the value of index. If index is vmap'd, then the error message corresponding to the largest index value (across the whole batch) will be used.