
Enumerations

equinox.Enumeration

JAX-compatible enums.

Enumerations are defined using class syntax, and their items are looked up as attributes of the class:

import equinox as eqx
import jax
import jax.numpy as jnp

class RESULTS(eqx.Enumeration):
    success = "Hurrah!"
    linear_solve_failed = "Linear solve failed to converge"
    diffeq_solve_failed = "Differential equation solve exploded horribly"

result = RESULTS.success

Enumerations support equality checking:

x = jnp.where(result == RESULTS.success, a, b)

Enumerations cannot be compared against anything except an Enumeration of the same type:

result == 0  # ValueError at trace time: `0` is an integer, not an Enumeration.
result == SOME_OTHER_ENUMERATION.foo  # Also a ValueError: a different Enumeration type.

Enumerations can be passed through JIT:

jax.jit(lambda x: x)(RESULTS.success)

Enumeration items show their assigned string in their repr:

print(RESULTS.success)  # RESULTS<Hurrah!>

Given an Enumeration item, its string alone can be looked up by indexing the Enumeration class with that item:

result = RESULTS.success
print(RESULTS[result])  # Hurrah!

Enumerations support inheritance: a subclass includes all of its superclasses' fields as well as any new ones. Note that you will need to add a # pyright: ignore comment wherever you inherit.

class RESULTS(eqx.Enumeration):
    success = "success"
    linear_solve_failed = "Linear solve failed to converge"
    diffeq_solve_failed = "Differential equation solve exploded horribly"

class MORE_RESULTS(RESULTS):  # pyright: ignore
    flux_capacitor_overloaded = "Run for your life!"

result = MORE_RESULTS.linear_solve_failed

Enumerations are often used to represent error conditions. As such, they have built-in support for raising runtime errors via equinox.error_if:

x = result.error_if(x, pred)

This is equivalent to x = eqx.error_if(x, pred, msg), where msg is the string corresponding to the enumeration item.

promote(item: Enumeration) -> Enumeration classmethod

Enumerations support .promote (on the class), which converts an item of a parent Enumeration into the corresponding item of this Enumeration.

Example

class RESULTS(eqx.Enumeration):
    success = "success"
    linear_solve_failed = "Linear solve failed to converge"
    diffeq_solve_failed = "Differential equation solve exploded horribly"

class MORE_RESULTS(RESULTS):  # pyright: ignore
    flux_capacitor_overloaded = "Run for your life!"

result = RESULTS.success

# This is a ValueError at trace time
result == MORE_RESULTS.success

# This works. You can only promote from superclasses to subclasses.
result = MORE_RESULTS.promote(result)
result == MORE_RESULTS.success

Arguments:

  • item: an item from a parent Enumeration.

Returns:

item, but as a member of this Enumeration.

where(pred: ArrayLike, a: Enumeration, b: Enumeration) -> Enumeration classmethod

Enumerations support .where (on the class), analogous to jnp.where.

Example

result = RESULTS.where(diff < tol, RESULTS.success, RESULTS.linear_solve_failed)
result = RESULTS.where(step < max_steps, result, RESULTS.diffeq_solve_failed)

Arguments:

  • pred: a scalar boolean array.
  • a: an item of the enumeration. Must be of the same Enumeration as .where is accessed from.
  • b: an item of the enumeration. Must be of the same Enumeration as .where is accessed from.

Returns:

a if pred is true, and b if pred is false.