
Manipulating PyTrees

equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree

A jtu.tree_map-broadcasted version of

model = model if update is None else model + update

This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None: wherever updates contains None, the corresponding part of model is returned unchanged.)

Arguments:

  • model: An arbitrary PyTree.
  • updates: Any PyTree that is a prefix of model.

Returns:

The updated model.
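The broadcasting rule above, including the special treatment of None, can be sketched in a few lines of plain JAX. This is a minimal illustration that assumes only jax.tree_util; apply_updates_sketch is a hypothetical name, not part of Equinox, and the library's own implementation may differ in detail. The usage part performs one plain SGD step in which the bias receives no update.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def apply_updates_sketch(model, updates):
    """Hypothetical re-implementation of the documented rule
    `model if update is None else model + update`, broadcast over a PyTree."""

    def _apply(update, leaf):
        # `leaf` may be a whole subtree of `model` when `updates` is only a prefix.
        return leaf if update is None else leaf + update

    # `updates` goes first so that it defines the structure being mapped over;
    # `is_leaf` makes None count as a node rather than as an empty subtree.
    return jtu.tree_map(_apply, updates, model, is_leaf=lambda x: x is None)


# One plain SGD step in which the bias "b" is not updated.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": None}
learning_rate = 0.1
# tree_map skips None by default, so "b" stays None in `updates`.
updates = jtu.tree_map(lambda g: -learning_rate * g, grads)
new_params = apply_updates_sketch(params, updates)
```

Because updates only needs to be a prefix of model, a single None can also stand in for an entire submodule, which is then returned untouched.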


equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)

Updates a PyTree out-of-place; a bit like using .at[].set() on a JAX array.

Arguments:

  • where: A callable PyTree -> Node or PyTree -> Sequence[Node]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
  • pytree: The PyTree to modify.
  • replace: Either a single element, or a sequence with the same length as the sequence returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
  • replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
  • is_leaf: As for jtu.tree_flatten. For example, pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.

Returns:

A copy of the input PyTree, with the appropriate modifications.

Example

This can be used to specify which weights of a model should be trained. For example, the following trains only the weight of the final linear layer of an MLP:

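import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import equinox as eqx

def loss(model, x, y):
    return jnp.mean((jax.vmap(model)(x) - y) ** 2)

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=jr.PRNGKey(0))
x, y = jnp.ones((4, 2)), jnp.zeros((4, 1))
# Freeze every leaf, then unfreeze only the final layer's weight.
trainable = jtu.tree_map(lambda _: False, model)
trainable = eqx.tree_at(lambda mlp: mlp.layers[-1].weight, trainable, replace=True)
grad_loss = eqx.filter_grad(loss, arg=trainable)
grads = grad_loss(model, x, y)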

equinox.tree_inference(pytree: PyTree, value: bool) -> PyTree

Convenience function for setting all inference attributes on a PyTree.

Equivalent to:

import jax.tree_util as jtu
import equinox

has_inference = lambda leaf: hasattr(leaf, "inference")

def where(pytree):
    # Select the inference attribute of every node that has one.
    return tuple(x.inference
                 for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
                 if has_inference(x))

new_pytree = equinox.tree_at(where, pytree, replace_fn=lambda _: value)

inference flags are used to toggle the behaviour of a number of the pre-built neural network layers, such as equinox.nn.Dropout or equinox.experimental.BatchNorm.

Arguments:

  • pytree: the PyTree to modify.
  • value: the value to set all inference attributes to.

Returns:

A copy of pytree with all inference flags set to value.