# Manipulating PyTrees

## `equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree`
A `jtu.tree_map`-broadcasted version of:

```python
model = model if update is None else model + update
```

This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as `optax.apply_updates`, except that it understands `None`.)
**Arguments:**

- `model`: An arbitrary PyTree.
- `updates`: Any PyTree that is a prefix of `model`.

**Returns:**

The updated model.
## `equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)`
Updates a PyTree out-of-place; a bit like using `.at[].set()` on a JAX array.
**Arguments:**

- `where`: A callable `PyTree -> Node` or `PyTree -> Sequence[Node]`. It should consume a PyTree with the same structure as `pytree`, and return the node or nodes that should be replaced. For example `where = lambda mlp: mlp.layers[-1].linear.weight`.
- `pytree`: The PyTree to modify.
- `replace`: Either a single element, or a sequence of the same length as returned by `where`. This specifies the replacements to make at the locations specified by `where`. Mutually exclusive with `replace_fn`.
- `replace_fn`: A function `Node -> Any`. It will be called on every node specified by `where`. The return value from `replace_fn` will be used in its place. Mutually exclusive with `replace`.
- `is_leaf`: As `jtu.tree_flatten`. For example pass `is_leaf=lambda x: x is None` to be able to replace `None` values using `tree_at`.
**Returns:**

A copy of the input PyTree, with the appropriate modifications.
### Example

This can be used to specify which weights of a model should be trained and which should be frozen. For example, the following will train only the weight of the final linear layer of an MLP:

```python
def loss(model, ...):
    ...

model = eqx.nn.MLP(...)
trainable = jtu.tree_map(lambda _: False, model)
trainable = equinox.tree_at(lambda mlp: mlp.layers[-1].linear.weight, trainable, replace=True)
grad_loss = equinox.filter_grad(loss, arg=trainable)
grads = grad_loss(model)
```

(Note that the `tree_at` call operates on `trainable`, the PyTree of booleans, not on `model` itself.)
## `equinox.tree_inference(pytree: PyTree, value: bool) -> PyTree`
Convenience function for setting all `inference` attributes on a PyTree.

Equivalent to:

```python
has_inference = lambda leaf: hasattr(leaf, "inference")

def where(pytree):
    return tuple(x.inference
                 for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
                 if has_inference(x))

equinox.tree_at(where, pytree, replace_fn=lambda _: value)
```
`inference` flags are used to toggle the behaviour of a number of the pre-built neural network layers, such as `equinox.nn.Dropout` or `equinox.experimental.BatchNorm`.
**Arguments:**

- `pytree`: The PyTree to modify.
- `value`: The value to set all `inference` attributes to.
**Returns:**

A copy of `pytree` with all `inference` flags set to `value`.