Manipulating PyTrees
Tip
equinox.tree_at is one of the most useful utilities available here: it allows for performing surgery (out-of-place updates) on PyTrees.
Filtering, partitioning, combining
equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with replace.
Example
import equinox as eqx
import jax.numpy as jnp
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]
Example
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]
Arguments:
- pytree is any PyTree.
- filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
  - True, in which case the leaf or subtree is kept;
  - False, in which case the leaf or subtree is replaced with replace;
  - a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
- inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- replace is what to replace any removed leaves with. Defaults to None.
- is_leaf: Optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as a leaf; False indicates that the subtree should be traversed as a PyTree.
Returns:
A PyTree of the same structure as pytree.
equinox.is_array(element: Any) -> bool
Returns True if element is a JAX array or NumPy array.
equinox.is_array_like(element: Any) -> bool
Returns True if element is a JAX array, a NumPy array, or a Python float/complex/bool/int.
equinox.is_inexact_array(element: Any) -> bool
Returns True if element is an inexact (i.e. floating or complex) JAX/NumPy array.
equinox.is_inexact_array_like(element: Any) -> bool
Returns True if element is an inexact JAX array, an inexact NumPy array, or a Python float or complex.
equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]
Splits a PyTree into two pieces. Equivalent to filter(...), filter(..., inverse=True), but slightly more efficient.
Info
See also equinox.combine to reconstitute the PyTree again.
equinox.combine(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Combines multiple PyTrees into one PyTree, by replacing None leaves.
Example
pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2) # [0, 1, 2]
Tip
The idea is that equinox.combine should be used to undo a call to equinox.filter or equinox.partition.
Arguments:
- *pytrees: a sequence of PyTrees all with the same structure.
- is_leaf: As equinox.partition.
Returns:
A PyTree with the same structure as its inputs. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees as they are iterated over.
Common operations on PyTrees
equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree
A jax.tree_util.tree_map-broadcasted version of
if update is None:
    return model
else:
    return model + update
This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None.) For example see the Train RNN example.
Arguments:
- model: An arbitrary PyTree.
- updates: Any PyTree that is a prefix of model.
Returns:
The updated model.
equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)
Modifies a leaf or subtree of a PyTree. (A bit like using .at[].set() on a JAX array.)
The modified PyTree is returned and the original input is left unchanged. Make sure to use the return value from this function!
Arguments:
- where: A callable PyTree -> Node or PyTree -> tuple[Node, ...]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].weight.
- pytree: The PyTree to modify.
- replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
- replace_fn: A function Node -> Any. It will be called on every node specified by where, and its return value will be used in place of that node. Mutually exclusive with replace.
- is_leaf: As jax.tree_util.tree_flatten. For example, pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.
Note that where should not depend on the type of any of the leaves of the pytree, e.g. given pytree = [1, 2, object(), 3], then where = lambda x: tuple(xi for xi in x if type(xi) is int) is not allowed. If you really need this behaviour then this example could instead be expressed as where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int).
Returns:
A copy of the input PyTree, with the appropriate modifications.
Example
import equinox as eqx
# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.
Example
This is useful for performing model surgery. For example:
mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
Info
Constructing analogous PyTrees, with the same structure but different leaves, is very common in JAX: for example when constructing the in_axes argument to jax.vmap.
To support this use-case, the returned PyTree is constructed without calling __init__, __post_init__, or __check_init__. This allows leaves to be replaced with anything, regardless of any custom constructor or custom checks in the original PyTree.
equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]
Returns True if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.
- For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
  - If rtol=0 and atol=0 (the default) then all their values must be equal. Otherwise they must satisfy jnp.allclose/np.allclose(leaf1, leaf2, rtol=rtol, atol=atol).
  - If typematch=False (the default) then JAX and NumPy arrays are considered equal to each other. If typematch=True then JAX and NumPy arrays are not considered equal to each other.
- For non-arrays, if typematch=False (the default) then equality is determined with just leaf1 == leaf2. If typematch=True then type(leaf1) == type(leaf2) is also required.
If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom result = tree_equal(...) is True if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.
Arguments:
- *pytrees: Any number of PyTrees each with any structure.
- typematch: Whether to additionally require that corresponding leaves should have the same type(...) as each other.
- rtol: Used to determine the rtol of jnp.allclose/np.allclose when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
- atol: As rtol.
Returns:
A boolean, or bool-typed tracer.
equinox.Partial (Module)
Like functools.partial, but treats the wrapped function, and partially-applied args and kwargs, as a PyTree.
This is very much like jax.tree_util.Partial. The difference is that the JAX version requires that func be specifically a function, and will silently misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast the Equinox version allows for arbitrary callables.
__init__(self, func, /, *args, **kwargs)
Arguments:
- func: the callable to partially apply.
- *args: any positional arguments to apply.
- **kwargs: any keyword arguments to apply.
__call__(self, *args, **kwargs)
Call the wrapped self.func.
Arguments:
- *args: the arguments to apply. Passed after those arguments passed during __init__.
- **kwargs: any keyword arguments to apply.
Returns:
The result of the wrapped function.
Unusual operations on PyTrees
equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]
Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.
Arguments:
- pytree: the PyTree node to flatten by one level.
Returns:
As jax.tree_util.tree_flatten: a list of leaves and a PyTreeDef.
Example
x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))
y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))
equinox.tree_check(pytree: Any) -> None
Checks if the PyTree has no self-references, and if all non-leaf nodes are unique Python objects. (For example something like x = [1]; y = [x, x] would fail as x appears twice in the PyTree.)
Having unique non-leaf nodes isn't actually a requirement that JAX imposes, but it will become true after passing through an operation like jax.{jit, grad, ...} (as JAX copies the PyTree without trying to preserve identity). As such some users like to use this function to assert that this invariant was already true prior to the transform, as a way to avoid surprises.
Example
a = 1
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is non-leaf!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is leaf
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is leaf
eqx.tree_check([None, None]) # passes, duplicate is leaf
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is non-leaf!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])
Arguments:
- pytree: the PyTree to check.
Returns:
Nothing.
Raises:
A ValueError if the PyTree is not well-formed.