
Manipulating PyTrees

Tip

equinox.tree_at is one of the most useful utilities here: it performs surgery (out-of-place updates) on PyTrees.

Filtering, partitioning, combining

equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with replace.

Example

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]

Example

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]

Arguments:

  • pytree is any PyTree.
  • filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
    • True, in which case the leaf or subtree is kept;
    • False, in which case the leaf or subtree is replaced with replace;
    • a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf is kept or replaced as appropriate.
  • inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
  • replace is what to replace any falsey leaves with. Defaults to None.
  • is_leaf: Optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as leaf; False indicates that the subtree should be traversed as a PyTree.

Returns:

A PyTree of the same structure as pytree.

equinox.is_array(element: Any) -> bool

Returns True if element is a JAX array or NumPy array.

equinox.is_array_like(element: Any) -> bool

Returns True if element is a JAX array, a NumPy array, or a Python float/complex/bool/int.

equinox.is_inexact_array(element: Any) -> bool

Returns True if element is an inexact (i.e. floating or complex) JAX/NumPy array.

equinox.is_inexact_array_like(element: Any) -> bool

Returns True if element is an inexact JAX array, an inexact NumPy array, or a Python float or complex.


equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]

Splits a PyTree into two pieces. Equivalent to filter(...), filter(..., inverse=True), but slightly more efficient.

Info

See also equinox.combine to reconstitute the PyTree again.


equinox.combine(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Combines multiple PyTrees into one PyTree, by replacing None leaves.

Example

pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2)  # [0, 1, 2]

Tip

The idea is that equinox.combine should be used to undo a call to equinox.filter or equinox.partition.

Arguments:

  • *pytrees: a sequence of PyTrees all with the same structure.
  • is_leaf: As equinox.partition.

Returns:

A PyTree with the same structure as its inputs. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees as they are iterated over.

Common operations on PyTrees

equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree

A jax.tree_util.tree_map-broadcasted version of

if update is None:
    return model
else:
    return model + update

This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None.) For an example, see the Train RNN example.

Arguments:

  • model: An arbitrary PyTree.
  • updates: Any PyTree that is a prefix of model.

Returns:

The updated model.


equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)

Modifies a leaf or subtree of a PyTree. (A bit like using .at[].set() on a JAX array.)

The modified PyTree is returned and the original input is left unchanged. Make sure to use the return value from this function!

Arguments:

  • where: A callable PyTree -> Node or PyTree -> tuple[Node, ...]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
  • pytree: The PyTree to modify.
  • replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
  • replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
  • is_leaf: As jax.tree_util.tree_flatten. For example, pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.

Note that where should not depend on the type of any of the leaves of the pytree, e.g. given pytree = [1, 2, object(), 3], then where = lambda x: tuple(xi for xi in x if type(xi) is int) is not allowed. If you really need this behaviour then this example could instead be expressed as where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int).

Returns:

A copy of the input PyTree, with the appropriate modifications.

Example

# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.

Example

This is useful for performing model surgery. For example:

mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)

See also the Tricks page.


equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]

Returns True if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.

  • For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
    • If rtol=0 and atol=0 (the default) then all their values must be equal. Otherwise they must satisfy jnp.allclose(leaf1, leaf2, rtol=rtol, atol=atol) (np.allclose for NumPy arrays).
    • If typematch=False (the default) then JAX and NumPy arrays are considered equal to each other. If typematch=True then JAX and NumPy arrays are not considered equal to each other.
  • For non-arrays, if typematch=False (the default) then equality is determined with just leaf1 == leaf2. If typematch=True then type(leaf1) == type(leaf2) is also required.

If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom result = tree_equal(...) is True if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.

Arguments:

  • *pytrees: Any number of PyTrees each with any structure.
  • typematch: Whether to additionally require that corresponding leaves should have the same type(...) as each other.
  • rtol: Used to determine the rtol of jnp.allclose/np.allclose when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
  • atol: As rtol.

Returns:

A boolean, or bool-typed tracer.


equinox.Partial (Module)

Like functools.partial, but treats the wrapped function, and the partially-applied args and kwargs, as a PyTree.

This is very much like jax.tree_util.Partial. The difference is that the JAX version requires func to be specifically a function, and will silently misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast, the Equinox version allows arbitrary callables.

__init__(self, func, /, *args, **kwargs)

Arguments:

  • func: the callable to partially apply.
  • *args: any positional arguments to apply.
  • **kwargs: any keyword arguments to apply.

__call__(self, *args, **kwargs)

Call the wrapped self.func.

Arguments:

  • *args: the arguments to apply. Passed after those arguments passed during __init__.
  • **kwargs: any keyword arguments to apply.

Returns:

The result of the wrapped function.

Unusual operations on PyTrees

equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]

Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.

Arguments:

  • pytree: the PyTree node to flatten by one level.

Returns:

As jax.tree_util.tree_flatten: a list of leaves and a PyTreeDef.

Example

x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))

y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))

equinox.tree_check(pytree: Any) -> None

Checks that the PyTree is well-formed: that it contains no self-references and no duplicate layers.

Precisely, a "duplicate layer" is any PyTree node with at least one child node that appears more than once (as the same Python object) within the PyTree.

Info

This is automatically called when creating an eqx.Module instance, to help avoid bugs from duplicating layers.

Example

a = 1
eqx.tree_check([a, a])  # passes, duplicate is a leaf

b = eqx.nn.Linear(...)
eqx.tree_check([b, b])  # fails, duplicate is nontrivial!

c = []  # empty list
eqx.tree_check([c, c])  # passes, duplicate is trivial

d = eqx.Module()
eqx.tree_check([d, d])  # passes, duplicate is trivial

eqx.tree_check([None, None])  # passes, duplicate is trivial

e = [1]
eqx.tree_check([e, e])  # fails, duplicate is nontrivial!

eqx.tree_check([[1], [1]])  # passes, not actually a duplicate: each `[1]`
                            # has the same structure, but they're different.

# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])

Arguments:

  • pytree: the PyTree to check.

Returns:

Nothing.

Raises:

A ValueError if the PyTree is not well-formed.