Manipulating PyTrees
Tip
equinox.tree_at is one of the most useful utilities available here, which allows for performing surgery (out-of-place updates) on PyTrees.
Filtering, partitioning, combining
equinox.filter(pytree: PyTree, filter_spec: PyTree[bool | Callable[[Any], bool]], inverse: bool = False, replace: Any = None, is_leaf: Callable[[Any], bool] | None = None) -> PyTree
Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying
the condition are replaced with replace.
Example
import equinox as eqx
import jax.numpy as jnp
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]
Example
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]
Arguments:
- pytree is any PyTree.
- filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
  - True, in which case the leaf or subtree is kept;
  - False, in which case the leaf or subtree is replaced with replace;
  - a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
- inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- replace is what to replace any falsey leaves with. Defaults to None.
- is_leaf: Optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as leaf; False indicates that the subtree should be traversed as a PyTree.
Returns:
A PyTree of the same structure as pytree.
equinox.is_array(element: Any) -> bool
Returns True if element is a JAX array or NumPy array.
equinox.is_array_like(element: Any) -> bool
Returns True if element is a JAX array, a NumPy array, or a Python
float/complex/bool/int.
equinox.is_inexact_array(element: Any) -> bool
Returns True if element is an inexact (i.e. floating or complex) JAX/NumPy
array.
equinox.is_inexact_array_like(element: Any) -> bool
Returns True if element is an inexact JAX array, an inexact NumPy array, or
a Python float or complex.
equinox.partition(pytree: PyTree, filter_spec: PyTree[bool | Callable[[Any], bool]], replace: Any = None, is_leaf: Callable[[Any], bool] | None = None) -> tuple[PyTree, PyTree]
Splits a PyTree into two pieces. Equivalent to
filter(...), filter(..., inverse=True), but slightly more efficient.
Info
See also equinox.combine to reconstitute the PyTree again.
equinox.combine(*pytrees: PyTree, is_leaf: Callable[[Any], bool] | None = None) -> PyTree
Combines multiple PyTrees into one PyTree, by replacing None leaves.
Example
pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2) # [0, 1, 2]
Tip
The idea is that equinox.combine should be used to undo a call to
equinox.filter or equinox.partition.
Arguments:
- *pytrees: a sequence of PyTrees all with the same structure.
- is_leaf: As equinox.partition.
Returns:
A PyTree with the same structure as its inputs. Each leaf will be the first
non-None leaf found in the corresponding leaves of pytrees as they are
iterated over.
Common operations on PyTrees
equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree
A jax.tree_util.tree_map-broadcasted version of
if update is None:
    return model
else:
    return model + update
This is often useful when updating a model's parameters via stochastic gradient
descent. (This function is essentially the same as optax.apply_updates, except
that it understands None.) For example see the
Train RNN example.
Arguments:
- model: An arbitrary PyTree.
- updates: Any PyTree that is a prefix of model.
Returns:
The updated model.
equinox.tree_at(where: Callable[[PyTree], Node | Sequence[Node]], pytree: PyTree, replace: Any | Sequence[Any] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Callable[[Any], bool] | None = None)
Modifies a leaf or subtree of a PyTree. (A bit like using .at[].set() on a JAX
array.)
The modified PyTree is returned and the original input is left unchanged. Make sure to use the return value from this function!
Arguments:
- where: A callable PyTree -> Node or PyTree -> tuple[Node, ...]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
- pytree: The PyTree to modify.
- replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
- replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
- is_leaf: As jax.tree_util.tree_flatten. For example pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.
Note that where should not depend on the type of any of the leaves of the
pytree, e.g. given pytree = [1, 2, object(), 3], then
where = lambda x: tuple(xi for xi in x if type(xi) is int) is not allowed. If you
really need this behaviour then this example could instead be expressed as
where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int).
Returns:
A new PyTree with the same structure as the input PyTree and the appropriate modifications.
(If donating JAX arrays on JIT boundaries, then note that this function does not make a copy of the JAX arrays.)
Example
# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.
Example
This is useful for performing model surgery. For example:
mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
Info
Constructing analogous PyTrees, with the same structure but different leaves, is
very common in JAX: for example when constructing the in_axes argument to
jax.vmap.
To support this use-case, the returned PyTree is constructed without calling
__init__, __post_init__, or
__check_init__. This allows the new leaves to be anything, regardless of any
custom constructor or custom checks in the original PyTree.
equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: Float[ArrayLike, ''] = 0.0, atol: Float[ArrayLike, ''] = 0.0) -> bool | Bool[Array, '']
Returns True if all input PyTrees are equal. Every PyTree must have the same
structure, and all leaves must be equal.
- For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
  - If rtol=0 and atol=0 (the default) then all their values must be equal. Otherwise they must satisfy {j}np.allclose(leaf1, leaf2, rtol=rtol, atol=atol).
  - If typematch=False (the default) then JAX and NumPy arrays are considered equal to each other. If typematch=True then JAX and NumPy are not considered equal to each other.
- For non-arrays, if typematch=False (the default) then equality is determined with just leaf1 == leaf2. If typematch=True then type(leaf1) == type(leaf2) is also required.
If used under JIT, and any JAX arrays are present, then this may return a tracer.
Use the idiom result = tree_equal(...) is True if you'd like to assert that the
result is statically true without dependence on the value of any traced arrays.
Arguments:
- *pytrees: Any number of PyTrees each with any structure.
- typematch: Whether to additionally require that corresponding leaves should have the same type(...) as each other.
- rtol: Used to determine the rtol of jnp.allclose/np.allclose when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
- atol: As rtol.
Returns:
A boolean, or bool-typed tracer.
equinox.Partial(equinox.Module)
Like functools.partial, but treats the wrapped function, and partially-applied
args and kwargs, as a PyTree.
This is very much like jax.tree_util.Partial. The difference is that the JAX
version requires that func be specifically a function -- and will silently
misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast
the Equinox version allows for arbitrary callables.
__init__(func: Callable[..., ~_Return], /, *args: Any, **kwargs: Any)
Arguments:
- func: the callable to partially apply.
- *args: any positional arguments to apply.
- **kwargs: any keyword arguments to apply.
__call__(*args: Any, **kwargs: Any) -> ~_Return
Call the wrapped self.func.
Arguments:
- *args: the arguments to apply. Passed after those arguments passed during __init__.
- **kwargs: any keyword arguments to apply.
Returns:
The result of the wrapped function.
Unusual operations on PyTrees
equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]
Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.
Arguments:
pytree: the PyTree node to flatten by one level.
Returns:
As jax.tree_util.tree_flatten: a list of leaves and a PyTreeDef.
Example
x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))
y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))
equinox.tree_check(pytree: Any) -> None
Checks if the PyTree has no self-references, and if all non-leaf nodes are unique
Python objects. (For example something like x = [1]; y = [x, x] would fail as x
appears twice in the PyTree.)
Having unique non-leaf nodes isn't actually a requirement that JAX imposes, but it
will become true after passing through an operation like jax.{jit, grad, ...} (as
JAX copies the PyTree without trying to preserve identity). As such some users like
to use this function to assert that this invariant was already true prior to the
transform, as a way to avoid surprises.
Example
a = 1
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is non-leaf!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is leaf
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is leaf
eqx.tree_check([None, None]) # passes, duplicate is leaf
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is non-leaf!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])
Arguments:
pytree: the PyTree to check.
Returns:
Nothing.
Raises:
A ValueError if the PyTree is not well-formed.