Partitioning and combining
These split a PyTree into pieces, or combine several PyTrees back into a single PyTree.
equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with replace.
Example
import jax.numpy as jnp
import equinox as eqx
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]
Example
import jax.numpy as jnp
import equinox as eqx
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]
Arguments:
- pytree is any PyTree.
- filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
  - True, in which case the leaf or subtree is kept;
  - False, in which case the leaf or subtree is replaced with replace;
  - a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
- inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- replace is what to replace any filtered-out leaves with (falsey leaves, or truthy leaves if inverse=True). Defaults to None.
- is_leaf: Optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as a leaf; False indicates that the subtree should be traversed as a PyTree.
Returns:
A PyTree of the same structure as pytree.
equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]
Splits a PyTree into two pieces. Equivalent to filter(...), filter(..., inverse=True), but slightly more efficient.
Info
See also equinox.combine to reconstitute the PyTree again.
equinox.combine(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Combines multiple PyTrees into one PyTree, by replacing None leaves.
Example
import equinox
pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2) # [0, 1, 2]
Tip
The idea is that equinox.combine should be used to undo a call to equinox.filter or equinox.partition.
Arguments:
- *pytrees: a sequence of PyTrees, all with the same structure.
- is_leaf: As for equinox.partition.
Returns:
A PyTree with the same structure as its inputs. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees as they are iterated over.
equinox.is_array(element: Any) -> bool
Returns True if element is a JAX array or NumPy array.
equinox.is_array_like(element: Any) -> bool
Returns True if element is a JAX array, a NumPy array, or a Python float/complex/bool/int.
equinox.is_inexact_array(element: Any) -> bool
Returns True if element is an inexact (i.e. floating-point or complex) JAX/NumPy array.
equinox.is_inexact_array_like(element: Any) -> bool
Returns True if element is an inexact JAX array, an inexact NumPy array, or a Python float or complex.