Equinox offers support for raising runtime errors.
How does this compare to checkify in core JAX?
checkify is not compatible with operations like jax.lax.scan. (It must be "functionalised" first, using jax.experimental.checkify.checkify, and you then need to pipe a handle to the error through your code.) In contrast, Equinox's errors will "just work" without any extra effort.
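For contrast, here is a minimal sketch of the checkify workflow described above, assuming a JAX version that ships jax.experimental.checkify. The function f is hypothetical. The point is that f has to be wrapped with checkify.checkify, which returns an error value alongside the output, and the caller must then inspect or throw that error explicitly.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def f(x):
    # Hypothetical function: checkify.check records a failure in a
    # functional error value instead of raising immediately.
    checkify.check(jnp.all(x >= 0), "x must be >= 0")
    return jnp.sqrt(x)


# f must be "functionalised" first: the wrapped version returns (error, output),
# and this error handle is what has to be piped through the rest of your code.
checked_f = jax.jit(checkify.checkify(f))
```

A caller would typically write err, out = checked_f(x) followed by err.throw(), which raises only if a check failed.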
checkify stores all errors encountered whilst running your program, and then raises them at the end of the JIT'd region. For example this means that (the JAX equivalent of) the following pseudocode will still end up in the infinite loop (because the end of the computation never arrives):
    if h < 0:
        error()
    while t < t_max:
        t += h
Meanwhile, Equinox's errors do not wait until the end, so the above computation will have the correct behaviour.
JAX's support for raising runtime errors is technically only experimental. In practice, this nonetheless seems to be stable enough that these are part of the public API for Equinox.
equinox.error_if(x: PyTree, pred: ArrayLike, msg: str) -> PyTree
Throws an error based on runtime values. Works even under JIT.
Arguments:

- x: will be returned unchanged. This is used to determine where the error check happens in the overall computation: it will happen after x is computed and before the return value is used. x can be any PyTree, and it must contain at least one array.
- pred: a boolean for whether to raise an error. Can be an array of bools; an error will be raised if any of them are True. If vmap'd then an error will be raised if any batch element has True.
- msg: the string to display as an error message.
In addition, the EQX_ON_ERROR environment variable determines how runtime errors are handled. Possible values are:
- EQX_ON_ERROR=raise will raise a runtime error.
- EQX_ON_ERROR=nan will return x with NaNs in it, and then continue the computation.
- EQX_ON_ERROR=breakpoint will open a debugger.
    - Note that this option may prevent certain compiler optimisations, so permanently fixing this value is not recommended.
    - You will need to also pass the -s flag to pytest, if you are also using that.
    - This will sometimes raise a trace-time error due to JAX bug #16732. (Bugs whilst debugging bugs, eek!) If this happens, then it can be worked around by additionally setting the EQX_ON_ERROR_BREAKPOINT_FRAMES variable to a small integer, which specifies how many frames upwards the debugger should capture. The JAX bug is triggered when taking too many frames.
After changing an environment variable, the Python process must be restarted.
Returns: the original argument x, unchanged. If this return value is unused then the error check will not be performed. (It will be removed as part of dead code elimination.)
Example:

    import jax
    from equinox import error_if

    @jax.jit
    def f(x):
        x = error_if(x, x < 0, "x must be >= 0")
        # ...use x in your computation...
        return x

    f(jax.numpy.array(-1))  # raises a runtime error under EQX_ON_ERROR=raise