Manipulating PyTrees
Tip
equinox.tree_at is one of the most useful utilities available here, allowing for surgery (out-of-place updates) on PyTrees.
Filtering, partitioning, combining
equinox.filter(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], inverse: bool = False, replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Filters out the leaves of a PyTree not satisfying a condition. Those not satisfying the condition are replaced with replace.
Example
import equinox as eqx
import jax.numpy as jnp
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, eqx.is_array)
# [(jnp.array(0), None), None]
Example
pytree = [(jnp.array(0), 1), object()]
result = eqx.filter(pytree, [(False, False), True])
# [(None, None), object()]
Arguments:
- pytree is any PyTree.
- filter_spec is a PyTree whose structure should be a prefix of the structure of pytree. Each of its leaves should either be:
    - True, in which case the leaf or subtree is kept;
    - False, in which case the leaf or subtree is replaced with replace;
    - a callable Leaf -> bool, in which case this is evaluated on the leaf or mapped over the subtree, and the leaf kept or replaced as appropriate.
- inverse switches the truthy/falsey behaviour: falsey results are kept and truthy results are replaced.
- replace is what to replace any falsey leaves with. Defaults to None.
- is_leaf: Optional function called at each node of the PyTree. It should return a boolean. True indicates that the whole subtree should be treated as a leaf; False indicates that the subtree should be traversed as a PyTree.
Returns:
A PyTree of the same structure as pytree.
equinox.is_array(element: Any) -> bool
Returns True if element is a JAX array or NumPy array.
equinox.is_array_like(element: Any) -> bool
Returns True if element is a JAX array, a NumPy array, or a Python float/complex/bool/int.
equinox.is_inexact_array(element: Any) -> bool
Returns True if element is an inexact (i.e. floating or complex) JAX/NumPy array.
equinox.is_inexact_array_like(element: Any) -> bool
Returns True if element is an inexact JAX array, an inexact NumPy array, or a Python float or complex.
equinox.partition(pytree: PyTree, filter_spec: PyTree[typing.Union[bool, Callable[[typing.Any], bool]]], replace: Any = None, is_leaf: Optional[Callable[[Any], bool]] = None) -> tuple[PyTree, PyTree]
Splits a PyTree into two pieces. Equivalent to filter(...), filter(..., inverse=True), but slightly more efficient.
Info
See also equinox.combine to reconstitute the PyTree again.
equinox.combine(*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Combines multiple PyTrees into one PyTree, by replacing None leaves.
Example
pytree1 = [None, 1, 2]
pytree2 = [0, None, None]
equinox.combine(pytree1, pytree2) # [0, 1, 2]
Tip
The idea is that equinox.combine should be used to undo a call to equinox.filter or equinox.partition.
Arguments:
- *pytrees: a sequence of PyTrees all with the same structure.
- is_leaf: As equinox.partition.
Returns:
A PyTree with the same structure as its inputs. Each leaf will be the first non-None leaf found in the corresponding leaves of pytrees as they are iterated over.
Common operations on PyTrees
equinox.apply_updates(model: PyTree, updates: PyTree) -> PyTree
A jax.tree_util.tree_map-broadcasted version of
if update is None:
    return model
else:
    return model + update
This is often useful when updating a model's parameters via stochastic gradient descent. (This function is essentially the same as optax.apply_updates, except that it understands None.) For example, see the Train RNN example.
Arguments:
- model: An arbitrary PyTree.
- updates: Any PyTree that is a prefix of model.
Returns:
The updated model.
equinox.tree_at(where: Callable[[PyTree], Union[Node, Sequence[Node]]], pytree: PyTree, replace: Union[Any, Sequence[Any]] = sentinel, replace_fn: Callable[[Node], Any] = sentinel, is_leaf: Optional[Callable[[Any], bool]] = None)
Modifies a leaf or subtree of a PyTree. (A bit like using .at[].set() on a JAX array.)
The modified PyTree is returned and the original input is left unchanged. Make sure to use the return value from this function!
Arguments:
- where: A callable PyTree -> Node or PyTree -> tuple[Node, ...]. It should consume a PyTree with the same structure as pytree, and return the node or nodes that should be replaced. For example where = lambda mlp: mlp.layers[-1].linear.weight.
- pytree: The PyTree to modify.
- replace: Either a single element, or a sequence of the same length as returned by where. This specifies the replacements to make at the locations specified by where. Mutually exclusive with replace_fn.
- replace_fn: A function Node -> Any. It will be called on every node specified by where. The return value from replace_fn will be used in its place. Mutually exclusive with replace.
- is_leaf: As jax.tree_util.tree_flatten. For example, pass is_leaf=lambda x: x is None to be able to replace None values using tree_at.
Note that where should not depend on the type of any of the leaves of the pytree. For example, given pytree = [1, 2, object(), 3], then where = lambda x: tuple(xi for xi in x if type(xi) is int) is not allowed. If you really need this behaviour then this example could instead be expressed as where = lambda x: tuple(xi for xi, yi in zip(x, pytree) if type(yi) is int).
Returns:
A copy of the input PyTree, with the appropriate modifications.
Example
import equinox as eqx
# Here is a pytree
tree = [1, [2, {"a": 3, "b": 4}]]
new_leaf = 5
get_leaf = lambda t: t[1][1]["a"]
new_tree = eqx.tree_at(get_leaf, tree, new_leaf)
# new_tree is [1, [2, {"a": 5, "b": 4}]]
# The original tree is unchanged.
Example
This is useful for performing model surgery. For example:
mlp = eqx.nn.MLP(...)
new_linear = eqx.nn.Linear(...)
get_last_layer = lambda m: m.layers[-1]
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
equinox.tree_equal(*pytrees: PyTree, typematch: bool = False, rtol: ArrayLike = 0.0, atol: ArrayLike = 0.0) -> Union[bool, Array]
Returns True if all input PyTrees are equal. Every PyTree must have the same structure, and all leaves must be equal.
- For JAX arrays, NumPy arrays, or NumPy scalars: they must have the same shape and dtype.
    - If rtol=0 and atol=0 (the default) then all their values must be equal. Otherwise they must satisfy allclose(leaf1, leaf2, rtol=rtol, atol=atol), using jnp.allclose or np.allclose as appropriate.
    - If typematch=False (the default) then JAX and NumPy arrays are considered equal to each other. If typematch=True then JAX and NumPy arrays are not considered equal to each other.
- For non-arrays, if typematch=False (the default) then equality is determined with just leaf1 == leaf2. If typematch=True then type(leaf1) == type(leaf2) is also required.
If used under JIT, and any JAX arrays are present, then this may return a tracer. Use the idiom result = tree_equal(...) is True if you'd like to assert that the result is statically true without dependence on the value of any traced arrays.
Arguments:
- *pytrees: Any number of PyTrees each with any structure.
- typematch: Whether to additionally require that corresponding leaves should have the same type(...) as each other.
- rtol: Used to determine the rtol of jnp.allclose/np.allclose when comparing inexact (floating or complex) arrays. Defaults to zero, i.e. requires exact equality.
- atol: As rtol.
Returns:
A boolean, or bool-typed tracer.
equinox.Partial (Module)
Like functools.partial, but treats the wrapped function, and partially-applied args and kwargs, as a PyTree.
This is very much like jax.tree_util.Partial. The difference is that the JAX version requires that func be specifically a function, and will silently misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast, the Equinox version allows for arbitrary callables.
__init__(self, func, /, *args, **kwargs)
Arguments:
- func: the callable to partially apply.
- *args: any positional arguments to apply.
- **kwargs: any keyword arguments to apply.
__call__(self, *args, **kwargs)
Call the wrapped self.func.
Arguments:
- *args: the arguments to apply. Passed after those arguments passed during __init__.
- **kwargs: any keyword arguments to apply.
Returns:
The result of the wrapped function.
Unusual operations on PyTrees
equinox.tree_flatten_one_level(pytree: PyTree) -> tuple[list[PyTree], PyTreeDef]
Returns the immediate subnodes of a PyTree node. If called on a leaf node then it will return just that leaf.
Arguments:
- pytree: the PyTree node to flatten by one level.
Returns:
As jax.tree_util.tree_flatten: a list of leaves and a PyTreeDef.
Example
x = {"a": 3, "b": (1, 2)}
eqx.tree_flatten_one_level(x)
# ([3, (1, 2)], PyTreeDef({'a': *, 'b': *}))
y = 4
eqx.tree_flatten_one_level(y)
# ([4], PyTreeDef(*))
equinox.tree_check(pytree: Any) -> None
Checks that the PyTree is well-formed: that it has no self-references, and no duplicate layers.
Precisely, a "duplicate layer" is any PyTree node that has at least one child node and that appears more than once (as the same Python object) in the PyTree.
Info
This is automatically called when creating an eqx.Module instance, to help avoid bugs from duplicating layers.
Example
a = 1
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is nontrivial!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is trivial
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is trivial
eqx.tree_check([None, None]) # passes, duplicate is trivial
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is nontrivial!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
# passes, not actually a duplicate: each Linear layer is a separate layer.
eqx.tree_check([eqx.nn.Linear(...), eqx.nn.Linear(...)])
Arguments:
- pytree: the PyTree to check.
Returns:
Nothing.
Raises:
A ValueError if the PyTree is not well-formed.