PyTree annotations
jaxtyping.PyTree
Represents a PyTree.
Annotations of the following sorts are supported:
```python
a: PyTree
b: PyTree[LeafType]
c: PyTree[LeafType, "T"]
d: PyTree[LeafType, "S T"]
e: PyTree[LeafType, "... T"]
f: PyTree[LeafType, "T ..."]
```
These correspond to:
a. A plain PyTree can be used as an annotation, in which case PyTree is simply a suggestively-named alternative to Any. (By definition, all types are PyTrees.)
b. PyTree[LeafType] denotes a PyTree all of whose leaves match LeafType. For example, PyTree[int] or PyTree[Union[str, Float32[Array, "b c"]]].
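For plain Python classes, the leaf check can be approximated with jax.tree_util alone: flatten the tree and test every leaf. The sketch below is a rough illustration, not jaxtyping's implementation; leaves_match is a hypothetical helper, and jaxtyping's own check additionally understands array annotations such as Float32[Array, "b c"].

```python
import jax


def leaves_match(tree, leaf_type) -> bool:
    """Hypothetical approximation of `PyTree[leaf_type]` for plain Python classes."""
    # tree_leaves flattens dicts, lists, tuples, etc.; everything else is a leaf.
    return all(isinstance(leaf, leaf_type) for leaf in jax.tree_util.tree_leaves(tree))


print(leaves_match({"a": 1, "b": (2, 3)}, int))  # True
print(leaves_match([1, "two"], int))  # False
print(leaves_match([1, "two"], (int, str)))  # True: a tuple of types behaves like a Union
```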
c. A structure name can also be passed. In this case jax.tree_util.tree_structure(...) will be called, and the result bound to the structure name. This can be used to mark that multiple PyTrees all have the same structure:
```python
def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
    ...
```
As with array shape annotations, structure names are bound within the dynamic context of a jaxtyping.jaxtyped decorator.
d. A composite structure can be declared. In this case the variable must have a PyTree structure equal to the composition of multiple previously-bound PyTree structures. For example:
```python
from jaxtyping import PyTree

def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
    ...

x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}  # structure is the composition of the structures of `y` and `x`
f(x, y, z)
```
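The composition in this example can be reproduced with jax.tree_util alone. The sketch relies on the compose method of jax.tree_util.PyTreeDef, which replaces every leaf of the outer treedef with a copy of the inner one; it illustrates what "S T" means and makes no claim about how jaxtyping implements the check internally.

```python
import jax

x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}

S = jax.tree_util.tree_structure(y)  # PyTreeDef({'key': *})
T = jax.tree_util.tree_structure(x)  # PyTreeDef((*, *))

# "S T": S is the outer structure, with a copy of T under each of its leaves.
print(S.compose(T) == jax.tree_util.tree_structure(z))  # True
# Order matters: "T S" would describe ({'key': *}, {'key': *}) instead.
print(T.compose(S) == jax.tree_util.tree_structure(z))  # False
```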
e. A structure can begin with "...", to denote that the lower levels of the PyTree must match the declared structure, but the upper levels can be arbitrary. As in the previous case, all named pieces must already have been seen and their structures bound.
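One way to phrase the "... T" rule in plain jax.tree_util terms: treat every subtree whose structure is exactly T as a single leaf, which yields the arbitrary upper levels, and then check that composing those upper levels with T recovers the whole tree. matches_suffix is a hypothetical helper written for illustration, not part of jaxtyping, and jaxtyping may perform this check differently.

```python
import jax


def matches_suffix(tree, T) -> bool:
    """Hypothetical check for "... T": arbitrary upper levels, lower levels equal to T."""
    structure = jax.tree_util.tree_structure
    # Stop flattening at any subtree whose structure is exactly T ...
    upper = structure(tree, is_leaf=lambda node: structure(node) == T)
    # ... then require the whole tree to be those upper levels with T under every leaf.
    return upper.compose(T) == structure(tree)


T = jax.tree_util.tree_structure((0, 0))
print(matches_suffix([(1, 2), (3, 4)], T))  # True
print(matches_suffix([(1, 2), 3], T))  # False: the second entry is not a 2-tuple
```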
f. A structure can end with "...", to denote that the upper levels of the PyTree must match the declared structure (that is, the declared structure must be a prefix of the PyTree's structure), but the lower levels can be arbitrary. As in the previous two cases, all named pieces must already have been seen and their structures bound.
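The "T ..." rule can likewise be tested with jax.tree_util. PyTreeDef.flatten_up_to(tree) splits a tree into one subtree per leaf of the treedef and raises an error when the treedef is not a prefix of the tree. matches_prefix is a hypothetical helper for illustration only; the sketch catches both ValueError and TypeError to avoid depending on the exact error type.

```python
import jax


def matches_prefix(tree, T) -> bool:
    """Hypothetical check for "T ...": upper levels equal to T, lower levels arbitrary."""
    try:
        # One subtree per leaf of T; raises if T is not a prefix of the tree's structure.
        T.flatten_up_to(tree)
    except (ValueError, TypeError):
        return False
    return True


T = jax.tree_util.tree_structure((0, 0))
print(matches_prefix(((1, 2), {"a": 3}), T))  # True: anything may hang below T's leaves
print(matches_prefix((1, 2, 3), T))  # False: a 3-tuple does not start with a 2-tuple
```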
jaxtyping.PyTreeDef
Alias for jax.tree_util.PyTreeDef, which is the type returned by jax.tree_util.tree_structure(...).
Path-dependent shapes
The prefix "?" may be used to indicate that the axis size can depend on which leaf of a PyTree the array is at. For example:
```python
from jaxtyping import Array, PyTree, Shaped

def f(
    x: PyTree[Shaped[Array, "?foo"], "T"],
    y: PyTree[Shaped[Array, "?foo"], "T"],
):
    pass
```
This requires that x and y have matching PyTree structures (due to the T annotation), that all of their leaves are one-dimensional arrays, and that corresponding pairs of leaves in x and y have the same size as each other.
Thus the following is allowed:
```python
import jax.numpy as jnp

x0 = jnp.arange(3)
x1 = jnp.arange(5)
y0 = jnp.arange(3) + 1
y1 = jnp.arange(5) + 1
f((x0, x1), (y0, y1))  # x0 matches y0, and x1 matches y1. All good!
```
But this is not:
```python
f((x1, x1), (y0, y1))  # x1 does not have a size matching y0!
```
Internally, all that is happening is that foo is replaced with 0foo for the first leaf, 1foo for the next leaf, and so on, so that each leaf gets a unique version of the name.
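The renaming can be mimicked by hand to see why the second call above fails. The sketch below is a simplified, hypothetical re-creation for one-dimensional leaves only: check_path_dependent_sizes is not a jaxtyping function, and it keeps one size per leaf position ("0foo", "1foo", ...) in a dictionary shared by x and y.

```python
import jax
import jax.numpy as jnp


def check_path_dependent_sizes(x, y, name="foo") -> bool:
    """Hypothetical sketch of the "?foo" renaming for two 1-D-leaf PyTrees."""
    # The "T" annotation: both trees must share one structure.
    if jax.tree_util.tree_structure(x) != jax.tree_util.tree_structure(y):
        return False
    sizes = {}  # maps "0foo", "1foo", ... to the size bound for that leaf position
    for tree in (x, y):
        for i, leaf in enumerate(jax.tree_util.tree_leaves(tree)):
            if leaf.ndim != 1:
                return False
            key = f"{i}{name}"
            # The first leaf at position i binds the size; later ones must agree.
            if sizes.setdefault(key, leaf.shape[0]) != leaf.shape[0]:
                return False
    return True


x0, x1 = jnp.arange(3), jnp.arange(5)
y0, y1 = jnp.arange(3) + 1, jnp.arange(5) + 1
print(check_path_dependent_sizes((x0, x1), (y0, y1)))  # True
print(check_path_dependent_sizes((x1, x1), (y0, y1)))  # False: 0foo is 5 for x but 3 for y
```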
Note that jaxtyping.{PyTree, PyTreeDef} are only available if JAX has been installed.