# Manipulating PyTrees

#### `equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree`

A `jtu.tree_map`-broadcasted version of

```
model = model if update is None else model + update
```

This is often useful when updating a model's parameters via stochastic gradient
descent. (This function is essentially the same as `optax.apply_updates`, except
that it understands `None`.)

**Arguments:**

- `model`: An arbitrary PyTree.
- `updates`: Any PyTree that is a prefix of `model`.

**Returns:**

The updated model.
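
For instance, here is a minimal sketch of a single gradient-descent step. The model, data, and learning rate are illustrative, not part of the original docs:

```
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx

model = eqx.nn.Linear(2, 2, key=jax.random.PRNGKey(0))
x, y = jnp.ones(2), jnp.zeros(2)

@eqx.filter_grad
def loss(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

grads = loss(model, x, y)  # has `None` at every non-differentiable leaf
updates = jtu.tree_map(lambda g: -0.01 * g, grads)  # `None` subtrees are skipped
model = eqx.apply_updates(model, updates)  # understands the `None`s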

#### `equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)`

Updates a PyTree out-of-place; a bit like using `.at[].set()` on a JAX array.

**Arguments:**

- `where`: A callable `PyTree -> Node` or `PyTree -> Sequence[Node]`. It should
  consume a PyTree with the same structure as `pytree`, and return the node or
  nodes that should be replaced. For example
  `where = lambda mlp: mlp.layers[-1].linear.weight`.
- `pytree`: The PyTree to modify.
- `replace`: Either a single element, or a sequence of the same length as
  returned by `where`. This specifies the replacements to make at the locations
  specified by `where`. Mutually exclusive with `replace_fn`.
- `replace_fn`: A function `Node -> Any`. It will be called on every node
  specified by `where`. The return value from `replace_fn` will be used in its
  place. Mutually exclusive with `replace`.
- `is_leaf`: As `jtu.tree_flatten`. For example, pass `is_leaf=lambda x: x is None`
  to be able to replace `None` values using `tree_at` (see the sketch after the
  example below).

**Returns:**

A copy of the input PyTree, with the appropriate modifications.
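
As a quick illustration, a minimal sketch on a plain container PyTree (the values here are made up):

```
import equinox as eqx

pytree = {"a": 1, "b": (2, 3)}

# Replace a single node with a new value...
eqx.tree_at(lambda t: t["b"][0], pytree, replace=20)
# -> {"a": 1, "b": (20, 3)}

# ...or compute the replacement from the existing node.
eqx.tree_at(lambda t: t["b"][0], pytree, replace_fn=lambda x: x * 10)
# -> {"a": 1, "b": (20, 3)}
```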

**Example**

This can be used to help specify the weights of a model to train or not to train. For example, the following will train only the weight of the final linear layer of an MLP:

```
def loss(model, ...):
    ...

model = eqx.nn.MLP(...)
# Mark every leaf as non-trainable, then flip just the final weight to True.
trainable = jtu.tree_map(lambda _: False, model)
trainable = equinox.tree_at(lambda mlp: mlp.layers[-1].linear.weight, trainable, replace=True)
grad_loss = equinox.filter_grad(loss, arg=trainable)
grads = grad_loss(model)
```
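
And the `is_leaf` trick mentioned above, as a minimal sketch:

```
import equinox as eqx

# `None` is normally an empty subtree, not a leaf, so `where` cannot point
# at it. Passing `is_leaf` makes it addressable.
pytree = [1, None, 3]
eqx.tree_at(lambda t: t[1], pytree, replace=2, is_leaf=lambda x: x is None)
# -> [1, 2, 3]
```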

#### `equinox.tree_inference(pytree: PyTree, value: bool) -> PyTree`

Convenience function for setting all `inference` attributes on a PyTree.

Equivalent to:

```
has_inference = lambda leaf: hasattr(leaf, "inference")

def where(pytree):
    return tuple(x.inference
                 for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
                 if has_inference(x))

equinox.tree_at(where, pytree, replace_fn=lambda _: value)
```

`inference` flags are used to toggle the behaviour of a number of the pre-built
neural network layers, such as `equinox.nn.Dropout` or
`equinox.experimental.BatchNorm`.

**Arguments:**

- `pytree`: The PyTree to modify.
- `value`: The value to set all `inference` attributes to.

**Returns:**

A copy of `pytree` with all `inference` flags set to `value`.
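
For example, a minimal sketch of switching a model into inference mode before evaluation (the container here is illustrative):

```
import equinox as eqx

model = {"drop": eqx.nn.Dropout(p=0.5)}  # any PyTree containing `inference` layers
eval_model = eqx.tree_inference(model, value=True)
# eval_model["drop"].inference is now True, so dropout is disabled.
```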