Filter, partition, combine
These functions split a PyTree into pieces, or combine several PyTrees back into a single PyTree.
equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, typing.Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with the value of the replace argument.
Arguments:
- pytree is any PyTree.
- filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should be one of:
    - True, in which case the leaf or subtree is kept;
    - False, in which case the leaf or subtree is replaced with replace;
    - a callable Leaf -> bool, in which case it is evaluated on the leaf (or mapped over the subtree), and the leaf is kept or replaced as appropriate.
- inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- replace is what to replace any falsey leaves with. Defaults to None.
- is_leaf: optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as a leaf; False indicates that the subtree should be traversed as a PyTree. This is mostly useful for evaluating a callable filter_spec on a node instead of a leaf.
Returns: a PyTree of the same structure as pytree.
A common special case is equinox.filter(pytree, equinox.is_array). Then equinox.is_array is evaluated on all of pytree's leaves, and each leaf is then kept or replaced.
equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, typing.Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Splits a PyTree into two pieces. Equivalent to (filter(...), filter(..., inverse=True)), but slightly more efficient.
Use equinox.combine to reconstitute the original PyTree from the two pieces.
equinox.combine(*pytrees: PyTree) -> PyTree
Combines multiple PyTrees into one PyTree, by replacing None leaves with the corresponding leaves of the other PyTrees. For example:
pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2)  # [0, 1, 2]
Arguments:
- *pytrees: a sequence of PyTrees, all with the same structure (treating None as a leaf).
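Returns: a PyTree with the same structure as its inputs. Each leaf is the first non-None leaf found among the corresponding leaves of pytrees, taken in the order the pytrees are passed; if every corresponding leaf is None, the result is None at that position.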