
Manipulating PyTrees

Tip

equinox.tree_at is one of the most useful utilities available here, which allows for performing surgery (out-of-place updates) on PyTrees.

Info

In addition to these utilities, also see equinox.partition, equinox.combine, and equinox.filter.

equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree

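A jtu.tree_map-broadcasted version of the following per-leaf rule:

def _apply_update(model, update):
    # A None update leaves the corresponding part of the model unchanged.
    if update is None:
        return model
    else:
        return model + update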

This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None.) For an example, see the Train RNN example.

Arguments:

  • model: An arbitrary PyTree.
  • updates: Any PyTree that is a prefix of model.

Returns:

The updated model.


equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)

Modifies a PyTree out-of-place. (A bit like using .at[].set() on a JAX array.)

Arguments:

  • where: A callable PyTree -> Node or PyTree -> tuple[Node, ...]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
  • pytree: The PyTree to modify.
  • replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
  • replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
  • is_leaf: As jtu.tree_flatten. For example pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.

Note that where should not depend on the type of any of the leaves of the pytree, e.g. given pytree = [1, 2, object(), 3], then where = lambda x: tuple(xi for xi in x if type(xi) is int) is not allowed. If you really need this behaviour then this example could instead be expressed as where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int).

Returns:

A copy of the input PyTree, with the appropriate modifications.

Example

import equinox as eqx

# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.

Example

This is useful for performing model surgery. For example:

import equinox as eqx

mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
See also the Tricks page.


equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]

Returns True if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.

  • For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
    • If rtol=0 and atol=0 (the default) then all their values must be equal. Otherwise they must satisfy np.allclose(leaf1, leaf2, rtol=rtol, atol=atol), or its jnp.allclose equivalent for JAX arrays.
    • If typematch=False (the default) then JAX and NumPy arrays are considered equal to each other. If typematch=True then a JAX array and a NumPy array are never considered equal.
  • For non-arrays, if typematch=False (the default) then equality is determined with just leaf1 == leaf2. If typematch=True then type(leaf1) == type(leaf2) is also required.

If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom result = tree_equal(...) is True if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.

Arguments:

  • *pytrees: Any number of PyTrees each with any structure.
  • typematch: Whether to additionally require that corresponding leaves should have the same type(...) as each other.
  • rtol: Used to determine the rtol of jnp.allclose/np.allclose when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
  • atol: As rtol.

Returns:

A boolean, or bool-typed tracer.


equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]

Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.

Arguments:

  • pytree: the PyTree node to flatten by one level.

Returns:

As jax.tree_util.tree_flatten: a list of leaves and a PyTreeDef.

Example

import equinox as eqx

x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))

y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))

equinox.tree_check(pytree: Any) -> None

Checks that the PyTree is well-formed: that it contains no self-references and no duplicate layers.

Precisely, a "duplicate layer" is any PyTree node with at least one child node that appears more than once (as the same Python object) in the PyTree.

Info

This is automatically called when creating an eqx.Module instance, to help avoid bugs from duplicating layers.

Example

import equinox as eqx

a = 1
eqx.tree_check([a, a])  # passes, duplicate is a leaf

b = eqx.nn.Linear(...)
eqx.tree_check([b, b])  # fails, duplicate is nontrivial!

c = []  # empty list
eqx.tree_check([c, c])  # passes, duplicate is trivial

d = eqx.Module()
eqx.tree_check([d, d])  # passes, duplicate is trivial

eqx.tree_check([None, None])  # passes, duplicate is trivial

e = [1]
eqx.tree_check([e, e])  # fails, duplicate is nontrivial!

eqx.tree_check([[1], [1]])  # passes, not actually a duplicate: each `[1]`
                            # has the same structure, but they're different objects.

# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])

Arguments:

  • pytree: the PyTree to check.

Returns:

Nothing.

Raises:

A ValueError if the PyTree is not well-formed.


equinox.Partial (Module)

Like functools.partial, but treats the wrapped function, and the partially-applied args and kwargs, as a PyTree.

This is very much like jax.tree_util.Partial. The difference is that the JAX version requires func to be specifically a function, and will silently misbehave if given any other callable, e.g. equinox.nn.MLP. In contrast, the Equinox version allows arbitrary callables.

__init__(self, func, /, *args, **kwargs)

Arguments:

  • func: the callable to partially apply.
  • *args: any positional arguments to apply.
  • **kwargs: any keyword arguments to apply.

__call__(self, *args, **kwargs)

Call the wrapped self.func.

Arguments:

  • *args: the positional arguments to apply. These are passed after the positional arguments given to __init__.
  • **kwargs: any keyword arguments to apply.

Returns:

The result of the wrapped function.