# Manipulating PyTrees

Tip

`equinox.tree_at` is one of the most useful utilities available here. It allows for performing surgery (out-of-place updates) on PyTrees.

## Filtering, partitioning, combining

#### `equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree`

Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying
the condition are replaced with `replace`.

Example

```
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]
```

Example

```
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]
```

**Arguments:**

- `pytree` is any PyTree.
- `filter_spec` is a PyTree whose structure should be a prefix of the structure of `pytree`. Each of its leaves should either be:
    - `True`, in which case the leaf or subtree is kept;
    - `False`, in which case the leaf or subtree is replaced with `replace`;
    - a callable `Leaf -> bool`, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
- `inverse` switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- `replace` is what to replace any falsey leaves with. Defaults to `None`.
- `is_leaf`: Optional function called at each node of the PyTree. It should return a boolean. `True` indicates that the whole subtree should be treated as leaf; `False` indicates that the subtree should be traversed as a PyTree.

**Returns:**

A PyTree of the same structure as `pytree`.

#### `equinox.is_array(element: Any) -> bool`

Returns `True` if `element` is a JAX array or NumPy array.

#### `equinox.is_array_like(element: Any) -> bool`

Returns `True` if `element` is a JAX array, a NumPy array, or a Python `float`/`complex`/`bool`/`int`.

#### `equinox.is_inexact_array(element: Any) -> bool`

Returns `True` if `element` is an inexact (i.e. floating or complex) JAX/NumPy array.

#### `equinox.is_inexact_array_like(element: Any) -> bool`

Returns `True` if `element` is an inexact JAX array, an inexact NumPy array, or a Python `float` or `complex`.

#### `equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]`

Splits a PyTree into two pieces. Equivalent to `filter(...), filter(..., inverse=True)`, but slightly more efficient.

Info

See also `equinox.combine` to reconstitute the PyTree again.

#### `equinox.combine(*pytrees: PyTree, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree`

Combines multiple PyTrees into one PyTree, by replacing `None` leaves.

Example

```
pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2) # [0, 1, 2]
```

Tip

The idea is that `equinox.combine` should be used to undo a call to `equinox.filter` or `equinox.partition`.

**Arguments:**

- `*pytrees`: a sequence of PyTrees all with the same structure.
- `is_leaf`: As `equinox.partition`.

**Returns:**

A PyTree with the same structure as its inputs. Each leaf will be the first non-`None` leaf found in the corresponding leaves of `pytrees` as they are iterated over.

## Common operations on PyTrees

#### `equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree`

A `jax.tree_util.tree_map`-broadcasted version of

```
if update is None:
    return model
else:
    return model + update
```

This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as `optax.apply_updates`, except that it understands `None`.) For example see the Train RNN example.

**Arguments:**

- `model`: An arbitrary PyTree.
- `updates`: Any PyTree that is a prefix of `model`.

**Returns:**

The updated model.

#### `equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)`

Modifies a leaf or subtree of a PyTree. (A bit like using `.at[].set()` on a JAX array.)

The modified PyTree is returned and the original input is left unchanged. Make sure to use the return value from this function!

**Arguments:**

- `where`: A callable `PyTree -> Node` or `PyTree -> tuple[Node, ...]`. It should consume a PyTree with the same structure as `pytree`, and return the node or nodes that should be replaced. For example `where = lambda mlp: mlp.layers[-1].linear.weight`.
- `pytree`: The PyTree to modify.
- `replace`: Either a single element, or a sequence of the same length as returned by `where`. This specifies the replacements to make at the locations specified by `where`. Mutually exclusive with `replace_fn`.
- `replace_fn`: A function `Node -> Any`. It will be called on every node specified by `where`. The return value from `replace_fn` will be used in its place. Mutually exclusive with `replace`.
- `is_leaf`: As `jtu.tree_flatten`. For example pass `is_leaf=lambda x: x is None` to be able to replace `None` values using `tree_at`.

Note that `where` should not depend on the type of any of the leaves of the pytree, e.g. given `pytree = [1, 2, object(), 3]`, then `where = lambda x: tuple(xi for xi in x if type(xi) is int)` is not allowed. If you really need this behaviour then this example could instead be expressed as `where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int)`.

**Returns:**

A copy of the input PyTree, with the appropriate modifications.

Example

```
# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.
```

Example

This is useful for performing model surgery. For example:

```
mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
```

Info

Constructing analogous PyTrees, with the same structure but different leaves, is very common in JAX: for example when constructing the `in_axes` argument to `jax.vmap`.

To support this use-case, the returned PyTree is constructed without calling `__init__`, `__post_init__`, or `__check_init__`. This allows for modifying leaves to be anything, regardless of the use of any custom constructor or custom checks in the original PyTree.

#### `equinox.tree_equal(*pytrees: PyTree, *, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]`

Returns `True` if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.

- For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
    - If `rtol=0` and `atol=0` (the default) then all their values must be equal. Otherwise they must satisfy `{j}np.allclose(leaf1, leaf2, rtol=rtol, atol=atol)`.
    - If `typematch=False` (the default) then JAX and NumPy arrays are considered equal to each other. If `typematch=True` then JAX and NumPy are not considered equal to each other.
- For non-arrays, if `typematch=False` (the default) then equality is determined with just `leaf1 == leaf2`. If `typematch=True` then `type(leaf1) == type(leaf2)` is also required.

If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom `result = tree_equal(...) is True` if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.

**Arguments:**

- `*pytrees`: Any number of PyTrees each with any structure.
- `typematch`: Whether to additionally require that corresponding leaves should have the same `type(...)` as each other.
- `rtol`: Used to determine the `rtol` of `jnp.allclose`/`np.allclose` when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
- `atol`: As `rtol`.

**Returns:**

A boolean, or bool-typed tracer.

#### `equinox.Partial (Module)`

Like `functools.partial`, but treats the wrapped function, and partially-applied args and kwargs, as a PyTree.

This is very much like `jax.tree_util.Partial`. The difference is that the JAX version requires that `func` be specifically a *function* -- and will silently misbehave if given any non-function callable, e.g. `equinox.nn.MLP`. In contrast the Equinox version allows for arbitrary callables.

##### `__init__(/, self, func, *args, **kwargs)`

**Arguments:**

- `func`: the callable to partially apply.
- `*args`: any positional arguments to apply.
- `**kwargs`: any keyword arguments to apply.

##### `__call__(self, *args, **kwargs)`

Call the wrapped `self.func`.

**Arguments:**

- `*args`: the arguments to apply. Passed after those arguments passed during `__init__`.
- `**kwargs`: any keyword arguments to apply.

**Returns:**

The result of the wrapped function.

## Unusual operations on PyTrees

#### `equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]`

Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.

**Arguments:**

- `pytree`: the PyTree node to flatten by one level.

**Returns:**

As `jax.tree_util.tree_flatten`: a list of leaves and a `PyTreeDef`.

Example

```
x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))
y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))
```

#### `equinox.tree_check(pytree: Any) -> None`

Checks if the PyTree has no self-references, and if all non-leaf nodes are unique Python objects. (For example something like `x = [1]; y = [x, x]` would fail as `x` appears twice in the PyTree.)

Having unique non-leaf nodes isn't actually a requirement that JAX imposes, but it will become true after passing through an operation like `jax.{jit, grad, ...}` (as JAX copies the PyTree without trying to preserve identity). As such some users like to use this function to assert that this invariant was already true prior to the transform, as a way to avoid surprises.

Example

```
a = 1
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is non-leaf!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is leaf
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is leaf
eqx.tree_check([None, None]) # passes, duplicate is leaf
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is non-leaf!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])
```

**Arguments:**

- `pytree`: the PyTree to check.

**Returns:**

Nothing.

**Raises:**

A `ValueError` if the PyTree is not well-formed.