# Manipulating PyTrees

#### equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree

A jtu.tree_map-broadcasted version of

```python
model = model if update is None else model + update
```


This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None: a None update leaves the corresponding part of model unchanged.)

Arguments:

• model: An arbitrary PyTree.
• updates: Any PyTree that is a prefix of model.

Returns:

The updated model.
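As a minimal sketch of the semantics above, assuming only jax.tree_util is available (apply_updates_sketch is a hypothetical name, not Equinox's source), the broadcasting can be written with tree_map plus an is_leaf that treats None as a leaf. Because updates drives the traversal, a single None can stand in for a whole subtree of model, which is what lets updates be a prefix of model.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def apply_updates_sketch(model, updates):
    """Add `updates` to `model` leaf by leaf; a None update keeps that subtree as-is."""

    def _apply(update, subtree):
        # None means "leave this part of the model unchanged".
        return subtree if update is None else subtree + update

    # `updates` is flattened first (with None counted as a leaf), and `model`
    # is flattened only up to that structure, so `updates` may be a prefix of it.
    return jtu.tree_map(_apply, updates, model, is_leaf=lambda x: x is None)


model = {
    "w": jnp.array([1.0, 2.0]),
    "b": jnp.array(0.5),
    "frozen": {"x": jnp.array(3.0), "y": jnp.array(4.0)},
}
grads = {"w": jnp.array([0.5, -0.5]), "b": jnp.array(1.0), "frozen": None}
learning_rate = 0.1
# One plain SGD step: update = -learning_rate * gradient; None entries stay None.
updates = jtu.tree_map(lambda g: -learning_rate * g, grads)
new_model = apply_updates_sketch(model, updates)
```

Here the "frozen" subtree has a single None update, so both of its arrays pass through untouched while "w" and "b" take the SGD step.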

#### equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)

Updates a PyTree out-of-place; a bit like using .at[].set() on a JAX array.

Arguments:

• where: A callable PyTree -> Node or PyTree -> Sequence[Node]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example, where = lambda mlp: mlp.layers[-1].weight.
• pytree: The PyTree to modify.
• replace: Either a single element, or a sequence with the same length as the sequence returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
• replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
• is_leaf: As jtu.tree_flatten. For example pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.

Returns:

A copy of the input PyTree, with the appropriate modifications.
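To make the role of where concrete, here is a minimal sketch of one way to implement out-of-place replacement with only jax.tree_util: swap every leaf for a unique marker, call where on the marked tree to learn which positions it selects, then rebuild the tree with new values at those positions. This is an assumption-laden illustration, not Equinox's implementation: tree_at_leaves_sketch and _Marker are hypothetical names, and unlike tree_at the sketch can only replace leaves, not whole subtrees such as an entire layer.

```python
import jax.tree_util as jtu

_MISSING = object()  # sentinel for "argument not supplied"


class _Marker:
    """Unique stand-in for one leaf; `index` is the leaf's flattened position."""

    def __init__(self, index):
        self.index = index


def tree_at_leaves_sketch(where, pytree, replace=_MISSING, replace_fn=_MISSING, is_leaf=None):
    if (replace is _MISSING) == (replace_fn is _MISSING):
        raise ValueError("Pass exactly one of `replace` or `replace_fn`.")
    leaves, treedef = jtu.tree_flatten(pytree, is_leaf=is_leaf)
    # Rebuild the tree with a marker at every leaf, then ask `where` which ones it picks.
    marked = jtu.tree_unflatten(treedef, [_Marker(i) for i in range(len(leaves))])
    selected = where(marked)
    single = isinstance(selected, _Marker)
    selected = (selected,) if single else tuple(selected)
    if not all(isinstance(m, _Marker) for m in selected):
        raise ValueError("This sketch can only replace leaves.")
    if replace_fn is _MISSING:
        values = (replace,) if single else tuple(replace)
        if len(values) != len(selected):
            raise ValueError("`replace` must have one entry per selected node.")
    else:
        values = tuple(replace_fn(leaves[m.index]) for m in selected)
    new_leaves = list(leaves)  # copy, so the input pytree is never mutated
    for marker, value in zip(selected, values):
        new_leaves[marker.index] = value
    return jtu.tree_unflatten(treedef, new_leaves)


tree = {"a": 1.0, "b": {"c": 2.0, "d": 3.0}}
new_tree = tree_at_leaves_sketch(lambda t: t["b"]["c"], tree, replace=20.0)
# new_tree == {"a": 1.0, "b": {"c": 20.0, "d": 3.0}}; `tree` itself is unchanged
```

Passing is_leaf=lambda x: x is None makes None a leaf during flattening, which is why the is_leaf argument is needed before a None value can be selected and replaced.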

Example

This can be used to specify which weights of a model should or should not be trained. For example, the following marks only the weight of the final linear layer of an MLP as trainable:

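```python
import equinox as eqx
import jax
import jax.tree_util as jtu

# Layer sizes here are illustrative.
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=jax.random.PRNGKey(0))
# Start from a mask that marks every leaf as not trainable...
trainable = jtu.tree_map(lambda _: False, model)
# ...then flip only the final linear layer's weight to True.
trainable = eqx.tree_at(lambda mlp: mlp.layers[-1].weight, trainable, replace=True)
```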
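To show how a boolean mask of this kind can decide which parameters actually move during training, here is a minimal pure-JAX sketch. It is an assumption-based stand-in, not Equinox's training API: a plain dict of arrays replaces the MLP, and params, loss and masked_sgd_step are hypothetical names. A gradient step is taken only at leaves whose mask entry is True.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# Hypothetical two-layer model stored as a plain dict of arrays.
params = {
    "hidden": {"weight": jnp.ones((2, 2)), "bias": jnp.zeros(2)},
    "final": {"weight": jnp.ones((2, 1)), "bias": jnp.zeros(1)},
}

# Same structure as `params`, with a bool at every leaf: only the final weight trains.
trainable = jtu.tree_map(lambda _: False, params)
trainable["final"]["weight"] = True  # fine for a plain dict; immutable Modules need tree_at


def loss(p, x, y):
    h = jnp.tanh(x @ p["hidden"]["weight"] + p["hidden"]["bias"])
    pred = h @ p["final"]["weight"] + p["final"]["bias"]
    return jnp.mean((pred - y) ** 2)


def masked_sgd_step(p, mask, x, y, learning_rate=0.1):
    grads = jax.grad(loss)(p, x, y)
    # Step only where the mask is True; frozen leaves pass through unchanged.
    return jtu.tree_map(
        lambda w, g, t: w - learning_rate * g if t else w, p, grads, mask
    )


x, y = jnp.ones((4, 2)), jnp.zeros((4, 1))
new_params = masked_sgd_step(params, trainable, x, y)
```

Gradients are still computed for every leaf here; the mask only controls which ones are applied, which keeps the sketch short at the cost of some wasted work.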


#### equinox.tree_inference(pytree: PyTree, value: bool) -> PyTree

Convenience function for setting all inference attributes on a PyTree.

Equivalent to:

```python
import equinox
import jax.tree_util as jtu

has_inference = lambda leaf: hasattr(leaf, "inference")

def where(pytree):
    return tuple(x.inference
                 for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
                 if has_inference(x))

equinox.tree_at(where, pytree, replace_fn=lambda _: value)
```


The inference flags toggle the behaviour of several pre-built neural network layers, such as equinox.nn.Dropout or equinox.experimental.BatchNorm.

Arguments:

• pytree: The PyTree to modify.
• value: The value to set all inference attributes to.

Returns:

A copy of pytree with all inference flags set to value.
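As a self-contained sketch of the same idea without Equinox, the snippet below treats anything with an inference attribute as a leaf (the same has_inference test as above) and returns a copy with the flag set. ToyDropout and tree_inference_sketch are hypothetical names, and the flagged layers are assumed to be frozen dataclasses so that dataclasses.replace can copy them; this is not how Equinox's own layers are defined.

```python
import dataclasses

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@dataclasses.dataclass(frozen=True)
class ToyDropout:
    """Hypothetical dropout layer whose behaviour is switched by an `inference` flag."""

    p: float = 0.5
    inference: bool = False

    def __call__(self, x, key):
        if self.inference:
            return x  # inference mode: dropout is the identity
        keep = jax.random.bernoulli(key, 1.0 - self.p, x.shape)
        return jnp.where(keep, x / (1.0 - self.p), 0.0)


has_inference = lambda leaf: hasattr(leaf, "inference")


def tree_inference_sketch(pytree, value):
    # Objects with an `inference` attribute are treated as leaves and copied with
    # the flag set (assumes they are dataclasses); other leaves pass through.
    return jtu.tree_map(
        lambda x: dataclasses.replace(x, inference=value) if has_inference(x) else x,
        pytree,
        is_leaf=has_inference,
    )


model = {"dropout": ToyDropout(p=0.5), "scale": jnp.array(2.0)}
eval_model = tree_inference_sketch(model, True)
```

As with tree_inference, the original model keeps its flags; only the returned copy is switched into inference mode.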