
Advanced features

Creating your own dtypes

jaxtyping.AbstractDtype

This is the base class of all dtypes. Subclass it to create your own custom collection of dtypes (analogous to Float, Inexact, etc.).

You must specify the class attribute dtypes. This can either be a string, a regex (as returned by re.compile(...)), or a tuple/list of strings/regexes.

At runtime, the array or tensor's dtype is converted to a string and compared against the string (an exact match is required) or regex. (String matching is performed, rather than just e.g. array.dtype == dtype, to provide cross-library compatibility between JAX/PyTorch/etc.)
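The matching rule above can be modelled in a few lines of plain Python. This is a minimal sketch, not jaxtyping's actual code: the helper dtype_matches is hypothetical, and the assumption that a regex must match the whole dtype name (via fullmatch) is ours, since the text only says the name is compared "against the regex".

```python
import re


def dtype_matches(dtype: object, spec: object) -> bool:
    """Toy model of the dtype-matching rule (not jaxtyping's actual code).

    `spec` may be a string, a compiled regex, or a tuple/list of these,
    mirroring the allowed values of the `dtypes` class attribute.
    """
    # The array's dtype is compared by its string form, e.g. "float32".
    name = str(dtype)
    # Treat a single string/regex as a one-element collection.
    if isinstance(spec, (str, re.Pattern)):
        spec = (spec,)
    for candidate in spec:
        if isinstance(candidate, re.Pattern):
            # Assumption: the regex must match the whole dtype name.
            if candidate.fullmatch(name):
                return True
        elif candidate == name:  # plain strings need an exact match
            return True
    return False
```

For instance, dtype_matches("uint32", re.compile(r"uint\d+")) is True, while dtype_matches("float32", "float") is False because a plain string is never treated as a prefix.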

Example

from typing import Union
from jaxtyping import AbstractDtype, Array, UInt8, UInt16

class UInt8or16(AbstractDtype):
    dtypes = ["uint8", "uint16"]

UInt8or16[Array, "shape"]

which is essentially equivalent to

Union[UInt8[Array, "shape"], UInt16[Array, "shape"]]
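To see why a collection of dtypes behaves like a Union on the dtype side, here is a toy model. ToyDtype and the ToyUInt* classes are hypothetical stand-ins for AbstractDtype, UInt8 and UInt16 (only string dtypes are handled, and shape checking is ignored entirely); the point is that a class listing several dtype names accepts exactly the names accepted by any of its single-dtype members.

```python
class ToyDtype:
    """Hypothetical stand-in for AbstractDtype; subclasses set `dtypes`."""

    dtypes: tuple = ()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Normalise a lone string into a one-element tuple of dtype names.
        if isinstance(cls.dtypes, str):
            cls.dtypes = (cls.dtypes,)
        cls.dtypes = tuple(cls.dtypes)

    @classmethod
    def accepts(cls, dtype_name: str) -> bool:
        # Exact string comparison, as for string entries in `dtypes`.
        return dtype_name in cls.dtypes


class ToyUInt8(ToyDtype):
    dtypes = "uint8"


class ToyUInt16(ToyDtype):
    dtypes = "uint16"


class ToyUInt8or16(ToyDtype):
    dtypes = ["uint8", "uint16"]


candidates = ["uint8", "uint16", "uint32", "int8", "float32"]
# The collection accepts exactly the dtypes accepted by either member.
accepted = [n for n in candidates if ToyUInt8or16.accepts(n)]
```

Here accepted is ["uint8", "uint16"], the same set you would get by accepting a name if either ToyUInt8 or ToyUInt16 accepts it.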

jaxtyping.make_numpy_struct_dtype(dtype: np.dtype, name: str)

Creates a type annotation for a NumPy structured array. It performs an exact match on the name, order, and dtype of all its fields.

Example

import numpy as np
from jaxtyping import make_numpy_struct_dtype

label_t = np.dtype([('first', np.uint8), ('second', np.int8)])
Label = make_numpy_struct_dtype(label_t, 'Label')

After that, you can use it just like any other jaxtyping.AbstractDtype:

a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t)

Arguments:

  • dtype: The numpy structured dtype to use.
  • name: The name to use for the returned Python class.

Returns:

A type annotation (a class named name) that matches dtype exactly, and that can be used like any other jaxtyping.AbstractDtype.
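The "exact match on name, order, and dtype" rule is easiest to see with fields written as an ordered list of (name, dtype) pairs. This is an illustrative sketch only: struct_matches is a hypothetical helper, and this list representation is a simplification of a real np.dtype, which carries more information.

```python
from typing import Sequence, Tuple

# Hypothetical simplified field description: (field name, dtype name).
Field = Tuple[str, str]


def struct_matches(expected: Sequence[Field], actual: Sequence[Field]) -> bool:
    """Exact match on field names, field order and field dtypes."""
    # Comparing ordered tuples (rather than dicts) makes field order significant.
    return tuple(expected) == tuple(actual)


label_t = [("first", "uint8"), ("second", "int8")]
swapped = [("second", "int8"), ("first", "uint8")]
```

Note that dict(label_t) == dict(swapped) is True, which is why an order-sensitive comparison is needed to reject the swapped layout.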

Printing axis bindings

jaxtyping.print_bindings()

Prints the values of the current jaxtyping axis bindings. Intended for debugging.

For example, this can be used to find the values bound to foo and bar in the following function:

from jaxtyping import Array, Float, jaxtyped, print_bindings

@jaxtyped(typechecker=...)
def f(x: Float[Array, "foo bar"]):
    print_bindings()
    ...

Note that these values are bound during runtime typechecking, so the jaxtyping.jaxtyped decorator is required.

Arguments:

Nothing.

Returns:

Nothing.
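Why the decorator is required can be illustrated with a toy model of axis binding. Everything below is hypothetical (toy_jaxtyped, check_shape and current_bindings are not jaxtyping functions, and jaxtyping's real mechanism may differ): each decorated call opens a fresh binding scope, the first use of an axis name binds it to a size, and later uses must agree. Outside a decorated call there is no scope, so nothing is bound.

```python
import contextvars
import functools

# Hypothetical per-call memo of axis bindings; None means "no active scope".
_bindings: contextvars.ContextVar = contextvars.ContextVar("bindings", default=None)


def toy_jaxtyped(fn):
    """Open a fresh binding scope for every call (loosely mirroring jaxtyped)."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        token = _bindings.set({})
        try:
            return fn(*args, **kwargs)
        finally:
            _bindings.reset(token)  # restore whatever scope was active before
    return wrapper


def check_shape(shape, spec):
    """Bind each named axis in `spec` to the matching dimension of `shape`."""
    memo = _bindings.get()
    if memo is None:
        raise RuntimeError("no binding scope: decorate the function first")
    names = spec.split()
    if len(names) != len(shape):
        raise TypeError(f"expected {len(names)} dims, got {len(shape)}")
    for name, size in zip(names, shape):
        bound = memo.setdefault(name, size)  # first use binds the axis
        if bound != size:
            raise TypeError(f"axis {name!r} bound to {bound}, got {size}")


def current_bindings():
    """Toy analogue of print_bindings(): return (not print) the bindings."""
    memo = _bindings.get()
    return dict(memo) if memo is not None else {}


@toy_jaxtyped
def f(x_shape, y_shape):
    check_shape(x_shape, "foo bar")
    check_shape(y_shape, "bar")
    return current_bindings()
```

Here f((3, 4), (4,)) returns {"foo": 3, "bar": 4}, whereas f((3, 4), (5,)) raises TypeError because bar is already bound to 4 within that call.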

Introspection

If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.

You can check for dtypes by doing issubclass(x, AbstractDtype). For example, issubclass(Float32, AbstractDtype) will pass.

You can check for arrays by doing issubclass(x, AbstractArray). Here, AbstractArray is the base class for all shape-and-dtype specified arrays, e.g. it's a base class for Float32[Array, "foo"].

You can check for pytrees by doing issubclass(x, PyTree). For example, issubclass(PyTree[int], PyTree) will pass.
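In a real type-hint parser, jaxtyping annotations may be nested inside Union, Optional or generic containers, and issubclass raises TypeError on non-classes such as Union[...] or List[int]. The sketch below is a hypothetical helper (find_subclasses is not part of jaxtyping) that walks a hint recursively and guards the issubclass call. So that it runs without jaxtyping installed, it is demonstrated on built-in classes; with jaxtyping you would pass AbstractDtype, AbstractArray or PyTree as base.

```python
import typing
from typing import Any, Dict, List, Optional, Union


def find_subclasses(hint: Any, base: type) -> list:
    """Collect every class inside `hint` that is a subclass of `base`.

    Recurses through typing constructs such as Union, Optional and List[...].
    """
    found = []
    # Only call issubclass() on plain classes: parameterised generics such as
    # List[int] have a non-None origin and would make issubclass() raise.
    if typing.get_origin(hint) is None and isinstance(hint, type):
        if issubclass(hint, base):
            found.append(hint)
    # Descend into the arguments of Union[...], Optional[...], Dict[...], etc.
    for arg in typing.get_args(hint):
        found.extend(find_subclasses(arg, base))
    return found
```

For example, find_subclasses(Optional[Union[bool, str]], int) returns [bool], since bool subclasses int and Optional flattens into the same Union.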