# Manipulating PyTrees

Tip

`equinox.tree_at` is one of the most useful utilities available here. It allows for performing surgery (out-of-place updates) on PyTrees.

## Filtering, partitioning, combining

#### `equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree`

Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with `replace`.

Example

```
import equinox as eqx
import jax.numpy as jnp

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]
```

Example

```
import equinox as eqx
import jax.numpy as jnp

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]
```

**Arguments:**

- `pytree` is any PyTree.
- `filter_spec` is a PyTree whose structure should be a prefix of the structure of `pytree`. Each of its leaves should be either:
    - `True`, in which case the leaf or subtree is kept;
    - `False`, in which case the leaf or subtree is replaced with `replace`;
    - a callable `Leaf -> bool`, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
- `inverse` switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- `replace` is what to replace any falsey leaves with. Defaults to `None`.
- `is_leaf`: Optional function called at each node of the PyTree. It should return a boolean. `True` indicates that the whole subtree should be treated as a leaf; `False` indicates that the subtree should be traversed as a PyTree.

**Returns:**

A PyTree of the same structure as `pytree`.

#### `equinox.is_array(element: Any) -> bool`

Returns `True` if `element` is a JAX array or NumPy array.

#### `equinox.is_array_like(element: Any) -> bool`

Returns `True` if `element` is a JAX array, a NumPy array, or a Python `float`/`complex`/`bool`/`int`.

#### `equinox.is_inexact_array(element: Any) -> bool`

Returns `True` if `element` is an inexact (i.e. floating or complex) JAX/NumPy array.

#### `equinox.is_inexact_array_like(element: Any) -> bool`

Returns `True` if `element` is an inexact JAX array, an inexact NumPy array, or a Python `float` or `complex`.

#### `equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]`

Splits a PyTree into two pieces. Equivalent to `filter(...), filter(..., inverse=True)`, but slightly more efficient.

Info

See also `equinox.combine` to reconstitute the PyTree again.

#### `equinox.combine(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree`

Combines multiple PyTrees into one PyTree, by replacing `None` leaves.

Example

```
import equinox

pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2)  # [0, 1, 2]
```

Tip

The idea is that `equinox.combine` should be used to undo a call to `equinox.filter` or `equinox.partition`.

**Arguments:**

- `*pytrees`: a sequence of PyTrees all with the same structure.
- `is_leaf`: As `equinox.partition`.

**Returns:**

A PyTree with the same structure as its inputs. Each leaf will be the first non-`None` leaf found in the corresponding leaves of `pytrees` as they are iterated over.

## Common operations on PyTrees

#### `equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree`

A `jax.tree_util.tree_map`-broadcasted version of

```
if update is None:
    return model
else:
    return model + update
```

This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as `optax.apply_updates`, except that it understands `None`.) For example, see the Train RNN example.

**Arguments:**

- `model`: An arbitrary PyTree.
- `updates`: Any PyTree that is a prefix of `model`.

**Returns:**

The updated model.

#### `equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)`

Modifies a leaf or subtree of a PyTree. (A bit like using `.at[].set()` on a JAX array.)

The modified PyTree is returned and the original input is left unchanged. Make sure to use the return value from this function!

**Arguments:**

- `where`: A callable `PyTree -> Node` or `PyTree -> tuple[Node, ...]`. It should consume a PyTree with the same structure as `pytree`, and return the node or nodes that should be replaced. For example `where = lambda mlp: mlp.layers[-1].linear.weight`.
- `pytree`: The PyTree to modify.
- `replace`: Either a single element, or a sequence of the same length as returned by `where`. This specifies the replacements to make at the locations specified by `where`. Mutually exclusive with `replace_fn`.
- `replace_fn`: A function `Node -> Any`. It will be called on every node specified by `where`. The return value from `replace_fn` will be used in its place. Mutually exclusive with `replace`.
- `is_leaf`: As `jtu.tree_flatten`. For example, pass `is_leaf=lambda x: x is None` to be able to replace `None` values using `tree_at`.

Note that `where` should not depend on the type of any of the leaves of the pytree, e.g. given `pytree = [1, 2, object(), 3]`, then `where = lambda x: tuple(xi for xi in x if type(xi) is int)` is not allowed. If you really need this behaviour then this example could instead be expressed as `where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int)`.

**Returns:**

A copy of the input PyTree, with the appropriate modifications.

Example

```
import equinox as eqx

# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.
```

Example

This is useful for performing model surgery. For example:

```
mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
```

#### `equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]`

Returns `True` if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.

- For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
    - If `rtol=0` and `atol=0` (the default) then all their values must be equal. Otherwise they must satisfy `{j}np.allclose(leaf1, leaf2, rtol=rtol, atol=atol)`.
    - If `typematch=False` (the default) then JAX and NumPy arrays are considered equal to each other. If `typematch=True` then JAX and NumPy arrays are not considered equal to each other.
- For non-arrays, if `typematch=False` (the default) then equality is determined with just `leaf1 == leaf2`. If `typematch=True` then `type(leaf1) == type(leaf2)` is also required.

If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom `result = tree_equal(...) is True` if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.

**Arguments:**

- `*pytrees`: Any number of PyTrees each with any structure.
- `typematch`: Whether to additionally require that corresponding leaves should have the same `type(...)` as each other.
- `rtol`: Used to determine the `rtol` of `jnp.allclose`/`np.allclose` when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
- `atol`: As `rtol`.

**Returns:**

A boolean, or bool-typed tracer.

#### `equinox.Partial (Module)`

Like `functools.partial`, but treats the wrapped function, and partially-applied args and kwargs, as a PyTree.

This is very much like `jax.tree_util.Partial`. The difference is that the JAX version requires that `func` be specifically a *function*, and will silently misbehave if given any non-function callable, e.g. `equinox.nn.MLP`. In contrast, the Equinox version allows for arbitrary callables.

##### `__init__(self, func, /, *args, **kwargs)`

**Arguments:**

- `func`: the callable to partially apply.
- `*args`: any positional arguments to apply.
- `**kwargs`: any keyword arguments to apply.

##### `__call__(self, *args, **kwargs)`

Call the wrapped `self.func`.

**Arguments:**

- `*args`: the arguments to apply. Passed after those arguments passed during `__init__`.
- `**kwargs`: any keyword arguments to apply.

**Returns:**

The result of the wrapped function.

## Unusual operations on PyTrees

#### `equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]`

Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.

**Arguments:**

- `pytree`: the PyTree node to flatten by one level.

**Returns:**

As `jax.tree_util.tree_flatten`: a list of leaves and a `PyTreeDef`.

Example

```
import equinox as eqx

x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))
y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))
```

#### `equinox.tree_check(pytree: Any) -> None`

Checks whether the PyTree is well-formed: that it has no self-references and no duplicate layers.

Precisely, a "duplicate layer" is any PyTree node, with at least one child node, that appears more than once in the PyTree.

Info

This is automatically called when creating an `eqx.Module` instance, to help avoid bugs from duplicating layers.

Example

```
a = 1
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is nontrivial!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is trivial
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is trivial
eqx.tree_check([None, None]) # passes, duplicate is trivial
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is nontrivial!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])
```

**Arguments:**

- `pytree`: the PyTree to check.

**Returns:**

Nothing.

**Raises:**

A `ValueError` if the PyTree is not well-formed.