Manipulating PyTrees
Tip
equinox.tree_at is one of the most useful utilities available here: it allows for performing surgery (out-of-place updates) on PyTrees.
Info
In addition to these utilities, also see equinox.partition, equinox.combine, and equinox.filter.
equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree
A jtu.tree_map-broadcasted version of
if update is None:
    return model
else:
    return model + update
This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None.) For example, see the Train RNN example.
Arguments:
model: An arbitrary PyTree.
updates: Any PyTree that is a prefix of model.
Returns:
The updated model.
equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)
Modifies a PyTree out-of-place. (A bit like using .at[].set() on a JAX array.)
Arguments:
where: A callable PyTree -> Node or PyTree -> tuple[Node, ...]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
pytree: The PyTree to modify.
replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
is_leaf: As jtu.tree_flatten. For example, pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.
Note that where should not depend on the type of any of the leaves of the pytree. For example, given pytree = [1, 2, object(), 3], then where = lambda x: tuple(xi for xi in x if type(xi) is int) is not allowed. If you really need this behaviour then this example could instead be expressed as where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int).
Returns:
A copy of the input PyTree, with the appropriate modifications.
Example
import equinox as eqx
# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.
Example
This is useful for performing model surgery. For example:
mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]
Returns True if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.
- For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
  - If rtol=0 and atol=0 (the default) then all their values must be equal. Otherwise they must satisfy allclose(leaf1, leaf2, rtol=rtol, atol=atol), using jnp.allclose or np.allclose as appropriate.
  - If typematch=False (the default) then JAX and NumPy arrays are considered equal to each other. If typematch=True then JAX and NumPy arrays are not considered equal to each other.
- For non-arrays, if typematch=False (the default) then equality is determined with just leaf1 == leaf2. If typematch=True then type(leaf1) == type(leaf2) is also required.
If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom result = tree_equal(...) is True if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.
Arguments:
*pytrees: Any number of PyTrees each with any structure.
typematch: Whether to additionally require that corresponding leaves should have the same type(...) as each other.
rtol: Used to determine the rtol of jnp.allclose/np.allclose when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
atol: As rtol.
Returns:
A boolean, or bool-typed tracer.
equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]
Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.
Arguments:
pytree: the PyTree node to flatten by one level.
Returns:
As jax.tree_util.tree_flatten: a list of leaves and a PyTreeDef.
Example
x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))
y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))
equinox.tree_check(pytree: Any) -> None
Checks if the PyTree is well-formed: that it has no self-references and no duplicate layers.
Precisely, a "duplicate layer" is any PyTree node with at least one child node that appears more than once in the PyTree, i.e. the same Python object is reachable at multiple locations.
Info
This is automatically called when creating an eqx.Module instance, to help avoid bugs from duplicating layers.
Example
a = 1
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is nontrivial!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is trivial
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is trivial
eqx.tree_check([None, None]) # passes, duplicate is trivial
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is nontrivial!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])
Arguments:
pytree: the PyTree to check.
Returns:
Nothing.
Raises:
A ValueError if the PyTree is not well-formed.
equinox.Partial (Module)
Like functools.partial, but treats the wrapped function, and partially-applied args and kwargs, as a PyTree.
This is very much like jax.tree_util.Partial. The difference is that the JAX version requires that func be specifically a function, and will silently misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast, the Equinox version allows for arbitrary callables.
__init__(self, func, /, *args, **kwargs)
Arguments:
func: the callable to partially apply.
*args: any positional arguments to apply.
**kwargs: any keyword arguments to apply.
__call__(self, *args, **kwargs)
Call the wrapped self.func.
Arguments:
*args: the arguments to apply. Passed after those arguments passed during __init__.
**kwargs: any keyword arguments to apply.
Returns:
The result of the wrapped function.