
PyTree annotations

jaxtyping.PyTree

Represents a PyTree.

Annotations of the following sorts are supported:

a: PyTree
b: PyTree[LeafType]
c: PyTree[LeafType, "T"]
d: PyTree[LeafType, "S T"]
e: PyTree[LeafType, "... T"]
f: PyTree[LeafType, "T ..."]

These correspond to:

a. A plain PyTree can be used as an annotation, in which case PyTree is simply a suggestively-named alternative to Any. (By definition all types are PyTrees.)

b. PyTree[LeafType] denotes a PyTree all of whose leaves match LeafType. For example, PyTree[int] or PyTree[Union[str, Float32[Array, "b c"]]].
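A minimal sketch of the leaf check described in (b), written directly against jax.tree_util rather than jaxtyping. The helper `leaves_match` is hypothetical and simplified: it does not handle leaf types that are themselves PyTree nodes, which jaxtyping may treat differently. Passing a tuple of types, as with isinstance, plays the role of a Union.

```python
from typing import Any

import jax


def leaves_match(tree: Any, leaf_type: Any) -> bool:
    """Hypothetical helper: True if every leaf of `tree` is an instance of `leaf_type`."""
    # tree_leaves flattens containers (tuples, lists, dicts, ...) down to their leaves.
    leaves = jax.tree_util.tree_leaves(tree)
    return all(isinstance(leaf, leaf_type) for leaf in leaves)


print(leaves_match({"a": 1, "b": [2, 3]}, int))           # True
print(leaves_match({"a": 1, "b": ["x", 3]}, int))         # False: "x" is a str leaf
print(leaves_match({"a": 1, "b": ["x", 3]}, (int, str)))  # True: Union-like check
```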

c. A structure name can also be passed. In this case jax.tree_util.tree_structure(...) is called on the argument, and the result is bound to the structure name. This can be used to mark that multiple PyTrees all have the same structure:

def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
    ...
Structures are bound to names in the same way as array shape annotations, i.e. within the thread-local dynamic context of a jaxtyping.jaxtyped decorator.
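To make the binding idea concrete, here is a hypothetical helper that does by hand what the "T" annotation asks for: compute the structure of x with jax.tree_util.tree_structure, treat it as the value bound to "T", and compare the structure of y against it. It is an illustration only; jaxtyping does this bookkeeping itself inside jaxtyped.

```python
import jax


def same_structure(x, y) -> bool:
    """Hypothetical helper mimicking `x: PyTree[..., "T"], y: PyTree[..., "T"]`."""
    # The first annotation binds the name "T" to the PyTreeDef of `x` ...
    bound_t = jax.tree_util.tree_structure(x)
    # ... and any later annotation using "T" must produce an equal PyTreeDef.
    return jax.tree_util.tree_structure(y) == bound_t


print(same_structure((1, {"a": 2}), (3, {"a": 4})))  # True
print(same_structure((1, {"a": 2}), (3, {"b": 4})))  # False: dict keys differ
```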

d. A composite structure can be declared. In this case the variable must have a PyTree structure equal to the composition of multiple previously-bound PyTree structures. For example:

def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
    ...

x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}  # structure is the composition of the structures of `y` and `x`
f(x, y, z)
When performing runtime type-checking, all the individual pieces must have already been bound to structures, otherwise the composite structure check will throw an error.
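The composition in the example above can be reproduced with JAX's PyTreeDef.compose, which replaces every leaf of one structure with a copy of another. This sketch assumes that "S T" corresponds to S.compose(T); it shows the structural relationship only, not how jaxtyping computes it.

```python
from jax.tree_util import tree_structure

x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}

# Structures bound to the names "T" and "S".
T = tree_structure(x)  # a 2-tuple of leaves
S = tree_structure(y)  # a dict with a single leaf under "key"

# "S T": put a copy of T in place of each leaf of S.
S_T = S.compose(T)
print(tree_structure(z) == S_T)  # True
```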

e. A structure can begin with a ..., to denote that the lower levels of the PyTree must match the declared structure, but the upper levels can be arbitrary. As in the previous case, all named pieces must already have been seen and their structures bound.
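One way to read "... T" is: walk down from the root, stop wherever a subtree has exactly the structure T, and require that every stopping point is such a subtree. The hypothetical lower_levels_match below sketches that reading using tree_leaves with an is_leaf predicate; jaxtyping's own algorithm may differ.

```python
from typing import Any

from jax.tree_util import PyTreeDef, tree_leaves, tree_structure


def lower_levels_match(tree: Any, inner: PyTreeDef) -> bool:
    """Hypothetical check for a "... T" annotation, with `inner` playing the role of T."""

    def is_inner(subtree: Any) -> bool:
        return tree_structure(subtree) == inner

    # Flatten from the root, treating any subtree with structure `inner` as a leaf.
    subtrees = tree_leaves(tree, is_leaf=is_inner)
    # A plain leaf reached without matching `inner` means the lower levels differ.
    return all(is_inner(s) for s in subtrees)


T = tree_structure((0, 0))  # stands in for a previously bound "T"
print(lower_levels_match({"a": (1, 2), "b": (3, 4)}, T))  # True
print(lower_levels_match({"a": (1, 2), "b": 3}, T))       # False
```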

f. A structure can end with a ..., to denote that the declared structure must be a prefix of the PyTree: the upper levels of the PyTree must match it, but the lower levels can be arbitrary. As in the previous two cases, all named pieces must already have been seen and their structures bound.
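For "T ...", JAX's PyTreeDef.flatten_up_to is a natural fit: it flattens a tree only down to the leaves of a given structure, and raises an error if that structure does not match the tree's upper levels. The hypothetical has_prefix helper below wraps it; this is a sketch, not necessarily how jaxtyping implements the check.

```python
from typing import Any

from jax.tree_util import PyTreeDef, tree_structure


def has_prefix(tree: Any, prefix: PyTreeDef) -> bool:
    """Hypothetical check for a "T ..." annotation, with `prefix` playing the role of T."""
    try:
        # Succeeds only if `prefix` matches the upper levels of `tree`;
        # whatever sits below each leaf of `prefix` is left unflattened.
        prefix.flatten_up_to(tree)
    except (ValueError, TypeError):
        return False
    return True


T = tree_structure((0, 0))  # stands in for a previously bound "T"
print(has_prefix(({"a": 1}, [2, 3]), T))  # True
print(has_prefix((1, 2, 3), T))           # False: tuple arity differs
```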


jaxtyping.PyTreeDef

Alias for jax.tree_util.PyTreeDef, which is the type returned by jax.tree_util.tree_structure(...).
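As a small illustration, the sketch below annotates a return value with jax.tree_util.PyTreeDef directly; since jaxtyping.PyTreeDef is an alias for it, either name could be used in the annotation. The function name describe is hypothetical.

```python
from jax.tree_util import PyTreeDef, tree_structure


def describe(tree) -> PyTreeDef:
    # Return only the structure of `tree`, discarding the leaf values.
    return tree_structure(tree)


treedef = describe({"a": 1, "b": (2, 3)})
print(isinstance(treedef, PyTreeDef))  # True
print(treedef.num_leaves)              # 3
# The structure can be refilled with new leaves, in flattening order.
print(treedef.unflatten([10, 20, 30]))  # {'a': 10, 'b': (20, 30)}
```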


Path-dependent shapes

The prefix ? may be used to indicate that the axis size can depend on which leaf of a PyTree the array is at. For example:

def f(
    x: PyTree[Shaped[Array, "?foo"], "T"],
    y: PyTree[Shaped[Array, "?foo"], "T"],
):
    pass
The above demands that x and y have matching PyTree structures (due to the T annotation), and that their leaves must all be one-dimensional arrays, and that the corresponding pairs of leaves in x and y must have the same size as each other.

Thus the following is allowed:

x0 = jnp.arange(3)
x1 = jnp.arange(5)

y0 = jnp.arange(3) + 1
y1 = jnp.arange(5) + 1

f((x0, x1), (y0, y1))  # x0 matches y0, and x1 matches y1. All good!

But this is not:

f((x1, x1), (y0, y1))  # x1 does not have a size matching y0!
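To spell out what the annotations on f demand, here is a hypothetical helper that performs the same three checks by hand with jax.tree_util: equal structures, one-dimensional leaves, and equal sizes for each corresponding pair of leaves. It illustrates the semantics only and is not a replacement for jaxtyping.

```python
import jax
import jax.numpy as jnp


def pairwise_sizes_match(x, y) -> bool:
    """Hypothetical helper mimicking the annotations on `f`, without jaxtyping."""
    # "T": x and y must have the same PyTree structure.
    if jax.tree_util.tree_structure(x) != jax.tree_util.tree_structure(y):
        return False
    # Equal structures flatten in the same order, so zip pairs corresponding leaves.
    for xi, yi in zip(jax.tree_util.tree_leaves(x), jax.tree_util.tree_leaves(y)):
        # "?foo": each leaf is 1-D, and paired leaves share the size of that axis.
        if xi.ndim != 1 or yi.ndim != 1 or xi.shape[0] != yi.shape[0]:
            return False
    return True


x0, x1 = jnp.arange(3), jnp.arange(5)
y0, y1 = jnp.arange(3) + 1, jnp.arange(5) + 1
print(pairwise_sizes_match((x0, x1), (y0, y1)))  # True
print(pairwise_sizes_match((x1, x1), (y0, y1)))  # False
```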

Internally, all that is happening is that foo is replaced with 0foo for the first leaf, 1foo for the next leaf, etc., so that each leaf gets a unique version of the name.
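The renaming can be sketched in a few lines of plain Python. The helper rename_path_dependent is hypothetical, works on a space-separated dimension string, and only mirrors the description above rather than jaxtyping's actual code.

```python
def rename_path_dependent(dims: str, leaf_index: int) -> str:
    """Hypothetical helper: rewrite each "?name" axis as "<leaf_index>name"."""
    renamed = []
    for dim in dims.split():
        if dim.startswith("?"):
            # Drop the "?" and prefix the leaf's position, e.g. "?foo" -> "0foo".
            dim = f"{leaf_index}{dim[1:]}"
        renamed.append(dim)
    return " ".join(renamed)


print([rename_path_dependent("?foo bar", i) for i in range(2)])
# ['0foo bar', '1foo bar']
```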


Note that jaxtyping.{PyTree, PyTreeDef} are only available if JAX has been installed.