Filter functions
These are built-in filter functions for the most common ways of splitting a PyTree into pieces (for example, into JAX arrays and non-JAX-arrays).
equinox.is_array(element: Any) -> bool
Returns True if element is a JAX array or NumPy array.
equinox.is_array_like(element: Any) -> bool
Returns True if element is a JAX array, a NumPy array, or a Python float, complex, bool, or int.
equinox.is_inexact_array(element: Any) -> bool
Returns True if element is an inexact (i.e. floating-point or complex) JAX/NumPy array.
equinox.is_inexact_array_like(element: Any) -> bool
Returns True if element is an inexact JAX array, an inexact NumPy array, or a Python float or complex.