Advanced features

Creating your own dtypes

jaxtyping.AbstractDtype

This is the base class of all dtypes. It can be used to create your own custom collection of dtypes (analogous to Float, Inexact, etc.).

You must specify the class attribute dtypes. This can either be a string, a regex (as returned by re.compile(...)), or a tuple/list of strings/regexes.

At runtime, the array or tensor's dtype is converted to a string and compared against the string (an exact match is required) or regex. (String matching is performed, rather than just e.g. array.dtype == dtype, to provide cross-library compatibility between JAX/PyTorch/etc.)


For example:

from typing import Union
from jaxtyping import AbstractDtype, Array, UInt8, UInt16

class UInt8or16(AbstractDtype):
    dtypes = ["uint8", "uint16"]

UInt8or16[Array, "shape"]

which is essentially equivalent to

Union[UInt8[Array, "shape"], UInt16[Array, "shape"]]
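
The string/regex matching rule described above can be illustrated with a small standard-library sketch. The helper dtype_matches is hypothetical and is not jaxtyping's implementation. It assumes that a regex must match the whole dtype string (re.fullmatch); the text above does not say how jaxtyping anchors patterns, so treat that detail as an assumption.

```python
import re
from typing import Sequence, Union

# A spec is a string or a compiled regex; `dtypes` may also be a list/tuple of them.
Spec = Union[str, re.Pattern]


def dtype_matches(dtype: object, dtypes: Union[Spec, Sequence[Spec]]) -> bool:
    """Return True if str(dtype) matches any entry of `dtypes`.

    Hypothetical helper illustrating string/regex matching, not jaxtyping's code.
    """
    dtype_str = str(dtype)  # compare strings, not library-specific dtype objects
    # Normalise a single string or regex into a one-element list.
    if isinstance(dtypes, (str, re.Pattern)):
        dtypes = [dtypes]
    for spec in dtypes:
        if isinstance(spec, str):
            if dtype_str == spec:  # plain strings need an exact match
                return True
        elif spec.fullmatch(dtype_str):  # assumption: regex must cover the whole string
            return True
    return False


UINT8_OR_16 = ["uint8", "uint16"]
ANY_UINT = re.compile(r"uint\d+")

print(dtype_matches("uint8", UINT8_OR_16))   # True
print(dtype_matches("uint32", UINT8_OR_16))  # False: no exact string match
print(dtype_matches("uint32", ANY_UINT))     # True: the regex accepts any width
print(dtype_matches("int32", ANY_UINT))      # False
```

Comparing strings rather than dtype objects is what lets one spec such as "uint8" apply to arrays from different libraries.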

Printing axis bindings

jaxtyping.print_bindings()

Prints the values of the current jaxtyping axis bindings. Intended for debugging.

During runtime type checking, axis names such as the foo and bar of Float[Array, "foo bar"] are bound to concrete sizes; this function prints out those values.
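
To make "axis bindings" concrete, here is a minimal standard-library sketch of the bookkeeping involved: each axis name is bound to a size the first time it appears, later appearances must agree, and the bindings can then be printed. bind_axes and show_bindings are hypothetical helpers. They only handle whitespace-separated named axes and do not reproduce jaxtyping's implementation or the output format of print_bindings().

```python
def bind_axes(bindings: dict, axes: str, shape: tuple) -> None:
    """Bind each name in `axes` (e.g. "foo bar") to the matching entry of `shape`.

    Hypothetical helper: only plain named axes are supported. On error,
    `bindings` is left unchanged.
    """
    names = axes.split()
    if len(names) != len(shape):
        raise TypeError(f"{axes!r} expects {len(names)} dims, got shape {shape}")
    new = dict(bindings)  # work on a copy so a failed check binds nothing
    for name, size in zip(names, shape):
        # setdefault returns the existing size if the name is already bound.
        if new.setdefault(name, size) != size:
            raise TypeError(f"axis {name!r} is bound to {new[name]}, got {size}")
    bindings.update(new)


def show_bindings(bindings: dict) -> None:
    # Debugging aid in the spirit of print_bindings(): one "name=size" per line.
    for name, size in bindings.items():
        print(f"{name}={size}")


bindings = {}
bind_axes(bindings, "foo bar", (3, 4))  # e.g. an argument of type Float[Array, "foo bar"]
bind_axes(bindings, "bar", (4,))        # "bar" agrees with the earlier binding
show_bindings(bindings)
# foo=3
# bar=4
```

A call such as bind_axes(bindings, "foo", (5,)) would raise TypeError, because foo is already bound to 3.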

Introspection

If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.

You can check for dtypes by doing issubclass(x, AbstractDtype). For example, issubclass(Float32, AbstractDtype) returns True.

You can check for arrays by doing issubclass(x, AbstractArray). Here, AbstractArray is the base class for all arrays with a specified shape and dtype; for example, it is a base class of Float32[Array, "foo"].

You can check for pytrees by doing issubclass(x, PyTree). For example, issubclass(PyTree[int], PyTree) returns True.
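
A type hint parser usually also meets annotations that are not classes at all (for example list[int], Optional[int], or a string forward reference), and issubclass raises TypeError on those. A minimal sketch, assuming you want such objects simply classified as "not a match", wraps the checks. safe_issubclass and classify are hypothetical helpers; numbers.Number and collections.abc.Sequence stand in for AbstractDtype, AbstractArray and PyTree so that the sketch runs with only the standard library.

```python
import numbers
from collections.abc import Sequence
from typing import Dict, Optional


def safe_issubclass(x: object, base: type) -> bool:
    # issubclass() raises TypeError when x is not a class (list[int],
    # Optional[int], string forward references, ...); treat that as "no".
    try:
        return issubclass(x, base)
    except TypeError:
        return False


def classify(x: object, bases: Dict[str, type]) -> Optional[str]:
    # Return the label of the first base class that x subclasses, else None.
    for label, base in bases.items():
        if safe_issubclass(x, base):
            return label
    return None


# Stand-ins; with jaxtyping you might instead pass
# {"dtype": AbstractDtype, "array": AbstractArray, "pytree": PyTree}.
BASES = {"number": numbers.Number, "sequence": Sequence}

print(classify(int, BASES))            # number
print(classify(list, BASES))           # sequence
print(classify(Optional[int], BASES))  # None
```

Checking the bases in a fixed order keeps the result deterministic if an object happens to subclass more than one of them.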