Advanced features
Creating your own dtypes
jaxtyping.AbstractDtype
This is the base class of all dtypes. It can be used to create your own custom
collection of dtypes (analogous to Float, Inexact, etc.).
You must specify the class attribute dtypes. This can be a string, a
regex (as returned by re.compile(...)), or a tuple/list of strings/regexes.
At runtime, the array or tensor's dtype is converted to a string and compared
against each string (an exact match is required) or regex. (String matching is
performed, rather than e.g. array.dtype == dtype, to provide cross-library
compatibility between JAX/PyTorch/etc.)
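To make the matching rule concrete, here is a simplified, hypothetical model of it in plain Python. dtype_matches is not part of jaxtyping, and treating a regex as a full match (re.fullmatch) is an assumption, since the text above does not say how regexes are anchored. NumPy is used only to show that str(array.dtype) yields plain names such as "uint8".

```python
import re
from typing import Sequence, Union

import numpy as np

Pattern = Union[str, re.Pattern]


def dtype_matches(dtype_str: str, dtypes: Union[Pattern, Sequence[Pattern]]) -> bool:
    """Hypothetical model of the rule above; not jaxtyping's own implementation."""
    # A lone string or regex behaves like a one-element list.
    if isinstance(dtypes, (str, re.Pattern)):
        dtypes = [dtypes]
    for pattern in dtypes:
        if isinstance(pattern, str):
            # Plain strings need an exact match.
            if dtype_str == pattern:
                return True
        # Assumption: a regex has to match the whole dtype string.
        elif pattern.fullmatch(dtype_str) is not None:
            return True
    return False


# NumPy renders dtypes as plain names, e.g. "uint8".
name = str(np.zeros(3, dtype=np.uint8).dtype)
print(name)                                          # uint8
print(dtype_matches(name, ["uint8", "uint16"]))      # True
print(dtype_matches(name, re.compile(r"u?int\d+")))  # True
print(dtype_matches("float32", "float"))             # False: not an exact match
```

The exact-string branch is what makes "float" fail to match "float32"; use a regex when you want a family of dtypes.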
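Example
from typing import Union
from jaxtyping import AbstractDtype, Array, UInt8, UInt16

class UInt8or16(AbstractDtype):
    dtypes = ["uint8", "uint16"]

UInt8or16[Array, "shape"]
which is essentially equivalent to
Union[UInt8[Array, "shape"], UInt16[Array, "shape"]]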
jaxtyping.make_numpy_struct_dtype(dtype: numpy.dtype, name: str)
Creates a type annotation for a NumPy structured array. It performs an exact match on the name, order, and dtype of all its fields.
Example
import numpy as np
from jaxtyping import make_numpy_struct_dtype

label_t = np.dtype([('first', np.uint8), ('second', np.int8)])
Label = make_numpy_struct_dtype(label_t, 'Label')
After that, you can use it just like any other jaxtyping.AbstractDtype:
a: Label[np.ndarray, 'a b'] = np.array([[(1, 0), (0, 1)]], dtype=label_t)
Arguments:
dtype: The numpy structured dtype to use.
name: The name to use for the returned Python class.
Returns:
A type annotation (a class named name) that exactly matches dtype when used like
any other jaxtyping.AbstractDtype.
Printing axis bindings
jaxtyping.print_bindings()
Prints the values of the current jaxtyping axis bindings. Intended for debugging.
For example, this can be used to find the values bound to foo and bar in

@jaxtyped(typechecker=...)
def f(x: Float[Array, "foo bar"]):
    print_bindings()
    ...

noting that these values are bound during runtime typechecking, so the
jaxtyping.jaxtyped decorator is required.
Arguments:
Nothing.
Returns:
Nothing.
Introspection
If you're writing your own type hint parser, then you may wish to detect if some Python object is a jaxtyping-provided type.
You can check for dtypes by doing issubclass(x, AbstractDtype). For example, issubclass(Float32, AbstractDtype) will pass.
You can check for arrays by doing issubclass(x, AbstractArray). Here, AbstractArray is the base class for all shape-and-dtype specified arrays, e.g. it's a base class for Float32[Array, "foo"].
You can check for pytrees by doing issubclass(x, PyTree). For example, issubclass(PyTree[int], PyTree) will pass.