
Manipulating PyTrees

Tip

equinox.tree_at is one of the most useful utilities available here, which allows for performing surgery (out-of-place updates) on PyTrees.

Filtering, partitioning, combining

equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with replace.

Example

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]

Example

pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]

Arguments:

  • pytree is any PyTree.
  • filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
    • True, in which case the leaf or subtree is kept;
    • False, in which case the leaf or subtree is replaced with replace;
    • a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
  • inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
  • replace is what to replace any falsey leaves with. Defaults to None.
  • is_leaf: Optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as leaf; False indicates that the subtree should be traversed as a PyTree.

Returns:

A PyTree of the same structure as pytree.

equinox.is_array(element: Any) -> bool

Returns True if element is a JAX array or NumPy array.

equinox.is_array_like(element: Any) -> bool

Returns True if element is a JAX array, a NumPy array, or a Python float/complex/bool/int.

equinox.is_inexact_array(element: Any) -> bool

Returns True if element is an inexact (i.e. floating or complex) JAX/NumPy array.

equinox.is_inexact_array_like(element: Any) -> bool

Returns True if element is an inexact JAX array, an inexact NumPy array, or a Python float or complex.


equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]

Splits a PyTree into two pieces. Equivalent to (filter(...), filter(..., inverse=True)), but slightly more efficient.

Info

See also equinox.combine to reconstitute the PyTree again.


equinox.combine(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Combines multiple PyTrees into one PyTree, by replacing None leaves.

Example

pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2)  # [0, 1, 2]

Tip

The idea is that equinox.combine should be used to undo a call to equinox.filter or equinox.partition.

Arguments:

  • *pytrees: a sequence of PyTrees all with the same structure.
  • is_leaf: As equinox.partition.

Returns:

A PyTree with the same structure as its inputs. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees as they are iterated over.

Common operations on PyTrees

equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree

A jax.tree_util.tree_map-broadcasted version of

if update is None:
    return model
else:
    return model + update

This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None.) For an example, see the Train RNN example.

Arguments:

  • model: An arbitrary PyTree.
  • updates: Any PyTree that is a prefix of model.

Returns:

The updated model.


equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)

Modifies a leaf or subtree of a PyTree. (A bit like using .at[].set() on a JAX array.)

The modified PyTree is returned and the original input is left unchanged. Make sure to use the return value from this function!

Arguments:

  • where: A callable PyTree -> Node or PyTree -> tuple[Node, ...]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
  • pytree: The PyTree to modify.
  • replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
  • replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
  • is_leaf: As jtu.tree_flatten. For example pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.

Note that where should not depend on the type of any of the leaves of the pytree, e.g. given pytree = [1, 2, object(), 3], then where = lambda x: tuple(xi for xi in x if type(xi) is int) is not allowed. If you really need this behaviour then this example could instead be expressed as where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int).

Returns:

A copy of the input PyTree, with the appropriate modifications.

Example

# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.

Example

This is useful for performing model surgery. For example:

mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)

See also the Tricks page.

Info

Constructing analogous PyTrees, with the same structure but different leaves, is very common in JAX: for example when constructing the in_axes argument to jax.vmap.

To support this use-case, the returned PyTree is constructed without calling __init__, __post_init__, or __check_init__. This allows for modifying leaves to be anything, regardless of the use of any custom constructor or custom checks in the original PyTree.


equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]

Returns True if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.

  • For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
    • If rtol=0 and atol=0 (the default) then all their values must be equal. Otherwise they must satisfy jnp.allclose/np.allclose(leaf1, leaf2, rtol=rtol, atol=atol).
    • If typematch=False (the default) then JAX and NumPy arrays are considered equal to each other. If typematch=True then JAX and NumPy arrays are not considered equal to each other.
  • For non-arrays, if typematch=False (the default) then equality is determined with just leaf1 == leaf2. If typematch=True then type(leaf1) == type(leaf2) is also required.

If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom result = tree_equal(...) is True if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.

Arguments:

  • *pytrees: Any number of PyTrees each with any structure.
  • typematch: Whether to additionally require that corresponding leaves should have the same type(...) as each other.
  • rtol: Used to determine the rtol of jnp.allclose/np.allclose when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
  • atol: As rtol.

Returns:

A boolean, or bool-typed tracer.


equinox.Partial (Module)

Like functools.partial, but treats the wrapped function, and the partially-applied args and kwargs, as a PyTree.

This is very much like jax.tree_util.Partial. The difference is that the JAX version requires that func be specifically a function, and will silently misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast the Equinox version allows for arbitrary callables.

__init__(self, func, /, *args, **kwargs)

Arguments:

  • func: the callable to partially apply.
  • *args: any positional arguments to apply.
  • **kwargs: any keyword arguments to apply.

__call__(self, *args, **kwargs)

Call the wrapped self.func.

Arguments:

  • *args: the arguments to apply. Passed after those arguments passed during __init__.
  • **kwargs: any keyword arguments to apply.

Returns:

The result of the wrapped function.

Unusual operations on PyTrees

equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]

Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.

Arguments:

  • pytree: the PyTree node to flatten by one level.

Returns:

As jax.tree_util.tree_flatten: a list of leaves and a PyTreeDef.

Example

x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))

y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))

equinox.tree_check(pytree: Any) -> None

Checks if the PyTree has no self-references, and if all non-leaf nodes are unique Python objects. (For example something like x = [1]; y = [x, x] would fail as x appears twice in the PyTree.)

Having unique non-leaf nodes isn't actually a requirement that JAX imposes, but it will become true after passing through an operation like jax.{jit, grad, ...} (as JAX copies the PyTree without trying to preserve identity). As such some users like to use this function to assert that this invariant was already true prior to the transform, as a way to avoid surprises.

Example

a = 1
eqx.tree_check([a, a])  # passes, duplicate is a leaf

b = eqx.nn.Linear(...)
eqx.tree_check([b, b])  # fails, duplicate is non-leaf!

c = []  # empty list
eqx.tree_check([c, c])  # passes, duplicate is leaf

d = eqx.Module()
eqx.tree_check([d, d])  # passes, duplicate is leaf

eqx.tree_check([None, None])  # passes, duplicate is leaf

e = [1]
eqx.tree_check([e, e])  # fails, duplicate is non-leaf!

eqx.tree_check([[1], [1]])  # passes, not actually a duplicate: each `[1]`
                            # has the same structure, but they're different.

# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])

Arguments:

  • pytree: the PyTree to check.

Returns:

Nothing.

Raises:

A ValueError if the PyTree is not well-formed.