equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree
A jtu.tree_map-broadcasted version of:

```python
model = model if update is None else model + update
```
This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None.)
Arguments:

- model: An arbitrary PyTree.
- updates: Any PyTree that is a prefix of model.

Returns:

The updated model.
equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)
Updates a PyTree out-of-place; a bit like using .at[].set() on a JAX array.
Arguments:

- where: A callable PyTree -> Node or PyTree -> Sequence[Node]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
- pytree: The PyTree to modify.
- replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
- replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
- is_leaf: As jtu.tree_flatten. For example pass is_leaf=lambda x: x is None in order to be able to replace None values using tree_at.
Returns:

A copy of the input PyTree, with the appropriate modifications.
This can be used to help specify the weights of a model to train or not to train. For example the following will train only the weight of the final linear layer of an MLP:
```python
def loss(model, ...):
    ...

model = eqx.nn.MLP(...)
trainable = jtu.tree_map(lambda _: False, model)
trainable = equinox.tree_at(lambda mlp: mlp.layers[-1].linear.weight, trainable, replace=True)
grad_loss = equinox.filter_grad(loss, arg=trainable)
grads = grad_loss(model)
```
equinox.tree_inference(pytree: PyTree, value: bool) -> PyTree
Convenience function for setting all inference attributes on a PyTree. Equivalent to:

```python
has_inference = lambda leaf: hasattr(leaf, "inference")

def where(pytree):
    return tuple(x.inference
                 for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
                 if has_inference(x))

equinox.tree_at(where, pytree, replace_fn=lambda _: value)
```

inference flags are used by several of the built-in neural network layers, such as equinox.nn.Dropout or equinox.nn.BatchNorm, to toggle between training and inference behaviour.
Arguments:

- pytree: the PyTree to modify.
- value: the value to set all inference attributes to.

Returns:

A copy of pytree with all inference flags set to value.