Advanced features
Creating your own dtypes
jaxtyping.AbstractDtype
This is the base class of all dtypes. It can be used to create your own custom collection of dtypes (analogous to Float, Inexact, etc.).
You must specify the class attribute dtypes. This can be a string, a regex (as returned by re.compile(...)), or a tuple/list of strings/regexes.
At runtime, the array or tensor's dtype is converted to a string and compared against the string (an exact match is required) or regex. (String matching is performed, rather than just e.g. array.dtype == dtype, to provide cross-library compatibility between JAX/PyTorch/etc.)
Example
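from typing import Union
from jaxtyping import AbstractDtype, Array, UInt8, UInt16

class UInt8or16(AbstractDtype):
    dtypes = ["uint8", "uint16"]

UInt8or16[Array, "shape"]
which is essentially equivalent to
Union[UInt8[Array, "shape"], UInt16[Array, "shape"]]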
jaxtyping.make_numpy_struct_dtype(dtype: np.dtype, name: str)
Creates a type annotation for a NumPy structured array. It performs an exact match on the name, order, and dtype of all its fields.
Example
import numpy as np
from jaxtyping import make_numpy_struct_dtype

label_t = np.dtype([('first', np.uint8), ('second', np.int8)])
Label = make_numpy_struct_dtype(label_t, 'Label')
After that, you can use it just like any other jaxtyping.AbstractDtype:
a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t)
Arguments:
dtype: The NumPy structured dtype to use.
name: The name to use for the returned Python class.
Returns:
A type annotation whose class name is the name argument, and which exactly matches the dtype argument when used like any other jaxtyping.AbstractDtype.
Printing axis bindings
jaxtyping.print_bindings()
Prints the values of the current jaxtyping axis bindings. Intended for debugging.
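For example, this can be used to find the values bound to foo and bar in
from jaxtyping import Array, Float, jaxtyped, print_bindings

@jaxtyped(typechecker=...)  # your runtime typechecker, e.g. beartype.beartype
def f(x: Float[Array, "foo bar"]):
    print_bindings()
    ...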
Note that these values are bound during runtime typechecking, so the jaxtyping.jaxtyped decorator is required.
Arguments:
Nothing.
Returns:
Nothing.
Introspection
If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
You can check for dtypes by doing issubclass(x, AbstractDtype). For example, issubclass(Float32, AbstractDtype) will pass.
You can check for arrays by doing issubclass(x, AbstractArray). Here, AbstractArray is the base class for all shape-and-dtype-specified arrays; e.g. it is a base class for Float32[Array, "foo"].
You can check for pytrees by doing issubclass(x, PyTree). For example, issubclass(PyTree[int], PyTree) will pass.