Runtime errors

Equinox offers support for raising runtime errors.

How does this compare to checkify in core JAX?

  1. checkify does not work directly inside operations like jax.jit or jax.lax.scan. The function must first be "functionalised" using jax.experimental.checkify.checkify, and you then need to pipe a handle to the error through your code. In contrast, Equinox's errors will "just work" without any extra effort.

  2. checkify stores all errors encountered whilst running your program, and only raises them at the end of the JIT'd region. For example, this means that (the JAX equivalent of) the following pseudocode will still end up in an infinite loop when h < 0 (because the end of the computation never arrives):

    if h < 0:
        error()
    while t < t_max:
        t += h
    

    Meanwhile, Equinox's errors do not wait until the end: the error is raised before the loop is entered, so the above computation will have the correct behaviour.

Warning

JAX's support for raising runtime errors is technically only experimental. In practice, this nonetheless seems to be stable enough that these are part of the public API for Equinox.

equinox.error_if(x: PyTree, pred: ArrayLike, msg: str) -> PyTree

Throws an error based on runtime values. Works even under JIT.

Arguments:

  • x: will be returned unchanged. This is used to determine where the error check happens in the overall computation: it will happen after x is computed and before the return value is used. x can be any PyTree, and it must contain at least one array.
  • pred: a boolean for whether to raise an error. Can be an array of bools; an error will be raised if any of them are True. If vmap'd then an error will be raised if any batch element has True.
  • msg: the string to display as an error message.

In addition, the EQX_ON_ERROR environment variable determines how any runtime errors are handled. Possible values are:

  • EQX_ON_ERROR=raise will raise a runtime error.
  • EQX_ON_ERROR=nan will return NaN instead of x, and then continue the computation.
  • EQX_ON_ERROR=breakpoint will open a debugger.
    • Note that this option may prevent certain compiler optimisations, so permanently fixing this value is not recommended.
    • If you are running under pytest, you will also need to pass it the -s flag.
    • By default this only allows you to see a single frame in the debugger. This is to work around JAX bug #16732. (Bugs whilst debugging bugs, eek!) In practice you may like to set the EQX_ON_ERROR_BREAKPOINT_FRAMES environment variable to a small integer, which specifies how many frames upwards the debugger should capture. The JAX bug is triggered when taking too many frames.

After changing an environment variable, the Python process must be restarted.

Returns:

The original argument x unchanged. If this return value is unused then the error check will not be performed. (It will be removed as part of dead code elimination.)

Example

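import jax
from equinox import error_if

@jax.jit
def f(x):
    # Rebind x to the return value so that the check is not eliminated.
    x = error_if(x, x < 0, "x must be >= 0")
    # ...use x in your computation...
    return x

# With the default EQX_ON_ERROR=raise, this raises a runtime error.
f(jax.numpy.array(-1))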

equinox.branched_error_if(x: PyTree, pred: ArrayLike, index: ArrayLike, msgs: Sequence[str]) -> PyTree

As equinox.error_if, but will raise one of several msgs depending on the value of index. If index is vmap'd, then the error message corresponding to the largest value of index (across the whole batch) will be used.