Partitioning and combining

These split a PyTree into pieces, or combine several PyTrees back into a single PyTree.

equinox.filter(pytree: PyTree, filter_spec: PyTree[Union[bool, Callable[[Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Filters a PyTree, keeping only the leaves that satisfy a condition. Leaves that do not satisfy the condition are replaced with replace.

Example

import equinox as eqx
import jax.numpy as jnp

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]

Example

import equinox as eqx
import jax.numpy as jnp

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]

Arguments:

  • pytree is any PyTree.
  • filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
    • True, in which case the leaf or subtree is kept;
    • False, in which case the leaf or subtree is replaced with replace;
    • a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and each leaf is kept or replaced as appropriate.
  • inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
  • replace is the value substituted for any leaf that is not kept. Defaults to None.
  • is_leaf: Optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as a leaf; False indicates that the subtree should be traversed as a PyTree.

Returns:

A PyTree of the same structure as pytree.


equinox.partition(pytree: PyTree, filter_spec: PyTree[Union[bool, Callable[[Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]

Splits a PyTree into two pieces. Equivalent to the pair (filter(...), filter(..., inverse=True)), but slightly more efficient.

Info

See also equinox.combine to reconstitute the PyTree again.


equinox.combine(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Combines multiple PyTrees into one PyTree, by replacing None leaves.

Example

import equinox

pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2)  # [0, 1, 2]

Tip

The idea is that equinox.combine should be used to undo a call to equinox.filter or equinox.partition.

Arguments:

  • *pytrees: a sequence of PyTrees all with the same structure.
  • is_leaf: As equinox.partition.

Returns:

A PyTree with the same structure as its inputs. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees as they are iterated over.


equinox.is_array(element: Any) -> bool

Returns True if element is a JAX array or NumPy array.


equinox.is_array_like(element: Any) -> bool

Returns True if element is a JAX array, a NumPy array, or a Python float/complex/bool/int.


equinox.is_inexact_array(element: Any) -> bool

Returns True if element is an inexact (i.e. floating point or complex) JAX/NumPy array.


equinox.is_inexact_array_like(element: Any) -> bool

Returns True if element is an inexact JAX array, an inexact NumPy array, or a Python float or complex.