Terminology for types
This API documentation uses a few convenient shorthands to refer to some types.
- Scalar refers to either an int, a float, or a JAX array with shape ().
- PyTree refers to any PyTree.
- Array refers to a JAX array.
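As a rough illustration of the three terms above, the sketch below checks whether a value counts as a Scalar and lists the leaves of a PyTree. The helper name `is_scalar` is hypothetical and is not part of this library; it only restates the definition in code.

```python
import jax
import jax.numpy as jnp


def is_scalar(x):
    """Hypothetical helper: True if x is a Scalar in the sense used above."""
    # A Python int or float counts as a Scalar.
    if isinstance(x, (int, float)):
        return True
    # So does a JAX array whose shape is the empty tuple ().
    return isinstance(x, jax.Array) and x.shape == ()


# An Array is any JAX array; here one with shape () and one with shape (3,).
zero_dim = jnp.asarray(1.5)
vector = jnp.ones(3)

print(is_scalar(2), is_scalar(0.5), is_scalar(zero_dim), is_scalar(vector))

# A PyTree is any nested structure of containers; its leaves are the values inside.
tree = {"a": 1.0, "b": (jnp.ones(2), 3)}
leaves = jax.tree_util.tree_leaves(tree)
print(len(leaves))
```

Only `vector` fails the check, because its shape is (3,) rather than ().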
In addition, the shapes and dtypes of Arrays are annotated:
- Array["dim1", "dim2"] refers to a JAX array with shape (dim1, dim2), and so on for other shapes.
- If a dimension is named in this way, then it should have the same size as the equally-named dimensions of all other arrays passed at the same time.
- Array[()] refers to an array with shape ().
- ... refers to an arbitrary number of dimensions.
- Array[bool] refers to a JAX array with Boolean dtype. (And so on for other dtypes.)
- These are combined via e.g. Array["dim1", "dim2", bool].
- The above syntax is essentially inspired by torchtyping.
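The annotations above are a documentation shorthand, but the naming rule can be made concrete with a small sketch. The helper `bind_dims` below is hypothetical (it is not provided by this library). It binds each dimension name to a size, requires equally-named dimensions to agree across arrays, and treats a trailing `...` as "any number of further dimensions".

```python
import jax.numpy as jnp


def bind_dims(arrays_and_names):
    """Hypothetical helper: map each dimension name to its size.

    `arrays_and_names` is a list of (array, names) pairs, where `names` is a
    tuple of strings, optionally ending in Ellipsis.
    Raises ValueError on a rank mismatch or an inconsistent named size.
    """
    sizes = {}
    for array, names in arrays_and_names:
        if names and names[-1] is Ellipsis:
            # Trailing `...`: only the leading dimensions are named.
            names = names[:-1]
            if array.ndim < len(names):
                raise ValueError("array has too few dimensions")
            shape = array.shape[: len(names)]
        else:
            if array.ndim != len(names):
                raise ValueError("array has the wrong number of dimensions")
            shape = array.shape
        for name, size in zip(names, shape):
            # The first occurrence binds the name; later ones must match it.
            if sizes.setdefault(name, size) != size:
                raise ValueError(f"dimension {name!r} has inconsistent sizes")
    return sizes


x = jnp.zeros((3, 4))                # Array["dim1", "dim2"]
mask = jnp.zeros((3, 4), dtype=bool)  # Array["dim1", "dim2", bool]
extra = jnp.zeros((3, 5, 6))          # leading dimension named, rest arbitrary

sizes = bind_dims([
    (x, ("dim1", "dim2")),
    (mask, ("dim1", "dim2")),
    (extra, ("dim1", Ellipsis)),
])
print(sizes)  # {'dim1': 3, 'dim2': 4}

# The dtype part of the annotation can be checked separately.
print(mask.dtype == jnp.bool_)
```

Passing an array of shape (4, 4) alongside `x` under the same names would raise a `ValueError`, since "dim1" would be bound to both 3 and 4.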
PyTree[T] is used to refer to a PyTree all of whose leaves have type T.
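To illustrate what the PyTree[T] notation describes, the following sketch flattens a PyTree and checks every leaf against a type. The function name `leaves_have_type` is hypothetical and is used here only for illustration.

```python
import jax
import jax.numpy as jnp


def leaves_have_type(tree, leaf_type):
    """Hypothetical helper: True if every leaf of `tree` is a `leaf_type`."""
    leaves = jax.tree_util.tree_leaves(tree)
    return all(isinstance(leaf, leaf_type) for leaf in leaves)


# Every leaf is a JAX array, so this matches PyTree[Array].
params = {"w": jnp.ones((2, 3)), "b": [jnp.zeros(3), jnp.zeros(())]}
print(leaves_have_type(params, jax.Array))

# One leaf is a Python int, so this does not match PyTree[Array].
mixed = {"w": jnp.ones(2), "step": 0}
print(leaves_have_type(mixed, jax.Array))
```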