# PyTree annotations

#### `jaxtyping.PyTree`

Represents a PyTree.

Annotations of the following sorts are supported:

```
a: PyTree
b: PyTree[LeafType]
c: PyTree[LeafType, "T"]
d: PyTree[LeafType, "S T"]
e: PyTree[LeafType, "... T"]
f: PyTree[LeafType, "T ..."]
```

These correspond to:

a. A plain `PyTree` can be used as an annotation, in which case `PyTree` is simply a suggestively-named alternative to `Any`. (By definition all types are PyTrees.)

b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
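
As a rough illustration of the property that `PyTree[LeafType]` describes (not of how jaxtyping actually checks it), the hypothetical helper `all_leaves_match` below flattens a tree with `jax.tree_util.tree_leaves` and tests each leaf with `isinstance`. It assumes `LeafType` is a plain class or a tuple of classes; jaxtyping's own check also understands array annotations such as `Float32[Array, "b c"]`, which this sketch does not.

```python
import jax


def all_leaves_match(tree, leaf_type) -> bool:
    """Hypothetical helper: True if every leaf of `tree` is an instance of `leaf_type`."""
    leaves = jax.tree_util.tree_leaves(tree)
    return all(isinstance(leaf, leaf_type) for leaf in leaves)


print(all_leaves_match({"a": 1, "b": (2, 3)}, int))  # True
print(all_leaves_match({"a": 1, "b": (2, "three")}, int))  # False
# A tuple of classes plays the role of Union[int, str].
print(all_leaves_match([1, "two"], (int, str)))  # True
```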

c. A structure name can also be passed. In this case `jax.tree_util.tree_structure(...)` will be called, and the result bound to the structure name.
This can be used to mark that multiple PyTrees all have the same structure:

```
from jaxtyping import PyTree

def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
    ...
```

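Structures are bound to names in the same way as array shape annotations, i.e. within the dynamic context of a `jaxtyping.jaxtyped` decorator.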
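
The structure comparison that binding `"T"` implies can be reproduced by hand with `jax.tree_util.tree_structure`. This sketch only compares `PyTreeDef` objects directly; it does not call jaxtyping, and the variable names are illustrative.

```python
import jax

x = {"a": 1, "b": (2, 3)}
y = {"a": 10, "b": (20, 30)}
z = {"a": 1, "b": [2, 3]}  # a list rather than a tuple

# The structure that "T" would be bound to after seeing `x`.
T = jax.tree_util.tree_structure(x)
print(T == jax.tree_util.tree_structure(y))  # True: leaf values do not matter
print(T == jax.tree_util.tree_structure(z))  # False: tuple vs. list
```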

d. A composite structure can be declared. In this case the variable must have a PyTree structure equal to the composition of multiple previously-bound PyTree structures. For example:

```
from jaxtyping import PyTree

def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
    ...

x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}  # structure is the composition of the structures of `y` and `x`
f(x, y, z)
```
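
To see why `z` matches `"S T"`, the composite structure can be built by hand: replace every leaf of `y` (structure `S`) with a copy of `x` (structure `T`) and take the structure of the result. Using `jax.tree_util.tree_map` for this is an assumption made for illustration; it is not necessarily how jaxtyping computes the composition internally.

```python
import jax

x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}

# Swap each leaf of `y` for a copy of `x`, then read off the structure: this is "S T".
composite = jax.tree_util.tree_structure(jax.tree_util.tree_map(lambda _: x, y))

print(composite == jax.tree_util.tree_structure(z))  # True
print(composite == jax.tree_util.tree_structure({"key": (4, 5, 6)}))  # False
```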

e. A structure can begin with a `...`, to denote that the lower levels of the PyTree must match the declared structure, but the upper levels can be arbitrary. As in the previous case, all named pieces must already have been seen and their structures bound.
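
A minimal sketch of the `"... T"` rule, using a hypothetical helper `has_suffix_structure` (not part of jaxtyping): flatten the tree top-down, with `is_leaf` stopping at any subtree whose structure equals the bound `T`, then require that every piece found this way has exactly that structure. Whatever sits above those pieces is the arbitrary upper part.

```python
import jax


def has_suffix_structure(tree, suffix) -> bool:
    """Hypothetical helper: True if `tree` is an arbitrary PyTree whose leaves
    are all subtrees with structure `suffix` (the "... T" pattern)."""
    treedef_of = jax.tree_util.tree_structure
    # Stop descending as soon as a subtree has exactly the suffix structure.
    pieces = jax.tree_util.tree_leaves(
        tree, is_leaf=lambda node: treedef_of(node) == suffix
    )
    return all(treedef_of(piece) == suffix for piece in pieces)


T = jax.tree_util.tree_structure((0, 0))
print(has_suffix_structure({"a": (1, 2), "b": [(3, 4), (5, 6)]}, T))  # True
print(has_suffix_structure({"a": (1, 2), "b": 3}, T))  # False
```

Stopping at the highest match is safe: a subtree that already has structure `T` can always serve as one of the `T`-shaped pieces, and the other branches are checked independently.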

f. A structure can end with a `...`, to denote that the declared structure must be a prefix of the PyTree's structure: the upper levels must match, but the lower levels can be arbitrary. As in the previous two cases, all named pieces must already have been seen and their structures bound.
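
The `"T ..."` rule can be sketched with `PyTreeDef.flatten_up_to`, which flattens a tree only down to the leaves of a given treedef and raises `ValueError` when that treedef is not a prefix of the tree. The helper name `has_prefix_structure` is hypothetical, and this is an illustration of the semantics rather than jaxtyping's implementation.

```python
import jax


def has_prefix_structure(tree, prefix) -> bool:
    """Hypothetical helper: True if `prefix` matches the upper levels of `tree`,
    with arbitrary subtrees allowed below each of `prefix`'s leaves."""
    try:
        prefix.flatten_up_to(tree)
    except ValueError:
        return False
    return True


T = jax.tree_util.tree_structure((0, 0))
print(has_prefix_structure(((1, 2), {"a": 3}), T))  # True: each slot holds any subtree
print(has_prefix_structure((1, 2, 3), T))  # False: three children, not two
print(has_prefix_structure([1, 2], T))  # False: list, not tuple
```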

#### `jaxtyping.PyTreeDef`

Alias for `jax.tree_util.PyTreeDef`, which is the type returned by `jax.tree_util.tree_structure(...)`.
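
A quick check of the relationship between `tree_structure` and `PyTreeDef`, using only `jax.tree_util`; since `jaxtyping.PyTreeDef` is an alias, it refers to this same class.

```python
import jax

treedef = jax.tree_util.tree_structure({"a": 1, "b": (2, 3)})

print(isinstance(treedef, jax.tree_util.PyTreeDef))  # True
print(treedef.num_leaves)  # 3
```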

## Path-dependent shapes

The prefix `?` may be used to indicate that the axis size can depend on which leaf of a PyTree the array is at. For example:

```
from jaxtyping import Array, PyTree, Shaped

def f(
    x: PyTree[Shaped[Array, "?foo"], "T"],
    y: PyTree[Shaped[Array, "?foo"], "T"],
):
    pass
```

This denotes that `x` and `y` have matching PyTree structures (due to the `T` annotation), that their leaves must all be one-dimensional arrays, *and that the corresponding pairs of leaves in `x` and `y` must have the same size as each other*. Thus the following is allowed:

```
import jax.numpy as jnp

x0 = jnp.arange(3)
x1 = jnp.arange(5)
y0 = jnp.arange(3) + 1
y1 = jnp.arange(5) + 1
f((x0, x1), (y0, y1))  # x0 matches y0, and x1 matches y1. All good!
```

But this is not:

```
f((x1, x1), (y0, y1)) # x1 does not have a size matching y0!
```

Internally, all that is happening is that `foo` is replaced with `0foo` for the first leaf, `1foo` for the next leaf, etc., so that each leaf gets a unique version of the name.
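
The renaming scheme above can be mimicked by hand. The hypothetical function `check_shared_foo` below re-creates, for this one example, the constraints implied by annotating both `x` and `y` as `PyTree[Shaped[Array, "?foo"], "T"]`: equal structures, one-dimensional leaves, and a per-leaf name (`0foo`, `1foo`, ...) that must bind to the same size in both trees. It is a sketch of the semantics, not jaxtyping's implementation, and it ignores dtypes since `Shaped` accepts any.

```python
import jax
import jax.numpy as jnp


def check_shared_foo(x, y) -> bool:
    """Hypothetical re-creation of the checks implied by annotating both
    `x` and `y` as PyTree[Shaped[Array, "?foo"], "T"]."""
    # "T": both trees must have the same structure.
    if jax.tree_util.tree_structure(x) != jax.tree_util.tree_structure(y):
        return False
    bound = {}  # maps per-leaf names such as "0foo", "1foo" to sizes
    for tree in (x, y):
        for i, leaf in enumerate(jax.tree_util.tree_leaves(tree)):
            if leaf.ndim != 1:  # "?foo" describes a single axis
                return False
            name = f"{i}foo"  # leaf-indexed renaming of "foo"
            if bound.setdefault(name, leaf.shape[0]) != leaf.shape[0]:
                return False
    return True


x0, x1 = jnp.arange(3), jnp.arange(5)
y0, y1 = jnp.arange(3) + 1, jnp.arange(5) + 1
print(check_shared_foo((x0, x1), (y0, y1)))  # True
print(check_shared_foo((x1, x1), (y0, y1)))  # False
```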

Note that `jaxtyping.{PyTree, PyTreeDef}` are only available if JAX has been installed.