Advanced features
Creating your own dtypes
jaxtyping.AbstractDtype
This is the base class of all dtypes. It can be subclassed to create your own custom collection of dtypes (analogous to Float, Inexact, etc.).
You must specify the class attribute dtypes. This can be a string, a regex (as returned by re.compile(...)), or a tuple/list of strings and/or regexes.
At runtime, the array or tensor's dtype is converted to a string and compared against the string (an exact match is required) or regex. (String matching is performed, rather than just e.g. array.dtype == dtype, to provide cross-library compatibility between JAX, PyTorch, etc.)
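The matching rule described above can be sketched in plain Python. The helper dtype_matches is hypothetical and is not part of jaxtyping; it only mimics the stated behaviour (exact equality for strings, regex matching otherwise). Whether jaxtyping anchors regexes to the whole name is not stated here, so this sketch assumes fullmatch.

```python
from __future__ import annotations

import re
from collections.abc import Sequence


def dtype_matches(
    dtype_name: str,
    spec: str | re.Pattern[str] | Sequence[str | re.Pattern[str]],
) -> bool:
    """Hypothetical helper (not part of jaxtyping) mimicking the rule above."""
    # Normalise a lone string or regex into a one-element tuple.
    if isinstance(spec, (str, re.Pattern)):
        spec = (spec,)
    for candidate in spec:
        if isinstance(candidate, str):
            # Strings require an exact match on the dtype's name.
            if dtype_name == candidate:
                return True
        # Assumption: a regex must match the whole name (fullmatch).
        elif candidate.fullmatch(dtype_name):
            return True
    return False


print(dtype_matches("uint8", ["uint8", "uint16"]))        # True
print(dtype_matches("uint32", ["uint8", "uint16"]))       # False
print(dtype_matches("float16", re.compile(r"float\d+")))  # True
print(dtype_matches("uint8", "uint"))                     # False: strings never match partially
```

Comparing names rather than dtype objects is what lets one rule serve several libraries: for NumPy, for instance, str(np.dtype("uint8")) is simply "uint8".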
Example
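from typing import Union
from jaxtyping import AbstractDtype, Array, UInt8, UInt16

class UInt8or16(AbstractDtype):
    dtypes = ["uint8", "uint16"]

UInt8or16[Array, "shape"]
# ...is essentially equivalent to:
Union[UInt8[Array, "shape"], UInt16[Array, "shape"]]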
Introspection
If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
You can check for dtypes by doing issubclass(x, AbstractDtype). For example, issubclass(Float32, AbstractDtype) returns True.
You can check for arrays by doing issubclass(x, AbstractArray). Here, AbstractArray is the base class for all shape-and-dtype specified arrays; for example, it is a base class of Float32[Array, "foo"].
You can check for pytrees by doing issubclass(x, PyTree). For example, issubclass(PyTree[int], PyTree) returns True.