Filter, partition, combine

These are the functions that actually split a PyTree into pieces, or combine several PyTrees back into a PyTree.

equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, typing.Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Filters out the leaves of a PyTree that do not satisfy a condition, replacing each of them with replace.

Arguments:

  • pytree is any PyTree.
  • filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should be one of:
    • True, in which case the leaf or subtree is kept;
    • False, in which case the leaf or subtree is replaced with replace;
    • a callable Leaf -> bool, in which case it is evaluated on the leaf or mapped over every leaf of the subtree, and each leaf is kept or replaced according to the result.
  • inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
  • replace is what to replace any falsey leaves with. Defaults to None.
  • is_leaf is an optional function called at each node of the PyTree, which should return a boolean. True means the whole subtree is treated as a leaf; False means the subtree is traversed as a PyTree. This is mostly useful for evaluating a callable filter_spec on a node instead of a leaf.

Returns:

A PyTree of the same structure as pytree.

Info

A common special case is equinox.filter(pytree, equinox.is_array). Then equinox.is_array is evaluated on all of pytree's leaves, and each leaf is then kept or replaced.


equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, typing.Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Equivalent to returning the pair filter(...), filter(..., inverse=True) with the same arguments, but slightly more efficient than making the two calls separately.

Info

See also equinox.combine to reconstitute the PyTree again.


equinox.combine(*pytrees: PyTree) -> PyTree

Combines multiple PyTrees into one PyTree by filling in None leaves with the corresponding leaves of the other PyTrees.

Example

import equinox

pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2)  # [0, 1, 2]

Tip

The idea is that equinox.combine should be used to undo a call to equinox.filter or equinox.partition.

Arguments:

  • *pytrees: a sequence of PyTrees all with the same structure.

Returns:

A PyTree with the same structure as its inputs. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees as they are iterated over.