Enumerations
equinox.Enumeration
JAX-compatible enums.
Enumerations are defined using class syntax, and their values are looked up on the class:
```python
import equinox as eqx

class RESULTS(eqx.Enumeration):
    success = "Hurrah!"
    linear_solve_failed = "Linear solve failed to converge"
    diffeq_solve_failed = "Differential equation solve exploded horribly"

result = RESULTS.success
```
Enumerations support equality checking:
```python
x = jnp.where(result == RESULTS.success, a, b)
```
Enumerations cannot be compared against anything except an Enumeration of the same type:
```python
result == 0  # ValueError at trace time: `0` is an integer, not an Enumeration.
result == SOME_OTHER_ENUMERATION.foo  # Likewise, also a ValueError.
```
Enumerations can be passed through JIT:
```python
jax.jit(lambda x: x)(RESULTS.success)
```
Enumerations use their assigned value in their repr:
```python
print(RESULTS.success)  # RESULTS<Hurrah!>
```
Given an Enumeration item, just its string can be looked up by indexing the class with it:
```python
result = RESULTS.success
print(RESULTS[result])  # Hurrah!
```
Enumerations support inheritance: a subclass includes all of its superclasses' fields as well as any new ones. Note that you will need to add a `# pyright: ignore` comment wherever you inherit:
```python
class RESULTS(eqx.Enumeration):
    success = "success"
    linear_solve_failed = "Linear solve failed to converge"
    diffeq_solve_failed = "Differential equation solve exploded horribly"

class MORE_RESULTS(RESULTS):  # pyright: ignore
    flux_capacitor_overloaded = "Run for your life!"

result = MORE_RESULTS.linear_solve_failed
```
Enumerations are often used to represent error conditions. As such they have built-in support for raising runtime errors, via `equinox.error_if`:
```python
x = result.error_if(x, pred)
```
This is equivalent to `x = eqx.error_if(x, pred, msg)`, where `msg` is the string corresponding to the enumeration item.
promote(item: Enumeration) -> Enumeration
classmethod
Enumerations support `.promote` (on the class) to promote an item of an inherited (parent) class to the corresponding item of this class.
Example
```python
import equinox as eqx

class RESULTS(eqx.Enumeration):
    success = "success"
    linear_solve_failed = "Linear solve failed to converge"
    diffeq_solve_failed = "Differential equation solve exploded horribly"

class MORE_RESULTS(RESULTS):  # pyright: ignore
    flux_capacitor_overloaded = "Run for your life!"

result = RESULTS.success
# This is a ValueError at trace time
result == MORE_RESULTS.success
# This works. You can only promote from superclasses to subclasses.
result = MORE_RESULTS.promote(result)
result == MORE_RESULTS.success
```
Arguments:

- `item`: an item from a parent Enumeration.

Returns:

`item`, but as a member of this Enumeration.
where(pred: ArrayLike, a: Enumeration, b: Enumeration) -> Enumeration
classmethod
Enumerations support `.where` (on the class), analogous to `jnp.where`.
Example
```python
result = RESULTS.where(diff < tol, RESULTS.success, RESULTS.linear_solve_failed)
result = RESULTS.where(step < max_steps, result, RESULTS.diffeq_solve_failed)
```
Arguments:

- `pred`: a scalar boolean array.
- `a`: an item of the enumeration. Must be of the same Enumeration as `.where` is accessed from.
- `b`: an item of the enumeration. Must be of the same Enumeration as `.where` is accessed from.

Returns:

`a` if `pred` is true, and `b` if `pred` is false.