Runtime type checking
(See the FAQ for details on static type checking.)
Runtime type checking synergises beautifully with `jax.jit`! All shape checks will be performed at trace-time only, and will not impact runtime performance.
The types provided by `jaxtyping`, e.g. `Float[Array, "batch channels"]`, are all compatible with `isinstance` checks, e.g. `isinstance(x, Float[Array, "batch channels"])`. This means that jaxtyping should be compatible with all runtime type checkers out-of-the-box.
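Why can a subscripted annotation be used with `isinstance` at all? Python lets a metaclass define `__instancecheck__`, so a class created by subscription can decide membership by inspecting the candidate object. The sketch below is a hypothetical toy model of that idea, not jaxtyping's implementation: `FakeArray` and `Shaped` are made-up names, only shapes are checked (no dtypes), and each check binds axis names in its own throwaway dictionary, which is exactly single-argument checking.

```python
from typing import NamedTuple, Tuple


class FakeArray(NamedTuple):
    """Hypothetical stand-in for an array: only the shape is modelled."""
    shape: Tuple[int, ...]


class _ShapedMeta(type):
    """Metaclass whose isinstance() check compares obj.shape with a shape spec."""

    def __instancecheck__(cls, obj) -> bool:
        shape = getattr(obj, "shape", None)
        if shape is None or len(shape) != len(cls.dims):
            return False
        bound = {}  # axis name -> size, private to this one check
        for name, size in zip(cls.dims, shape):
            if name.isdigit():
                if int(name) != size:  # a literal size such as "3" must match exactly
                    return False
            elif bound.setdefault(name, size) != size:
                return False  # a repeated name must have the same size every time
        return True


class Shaped(metaclass=_ShapedMeta):
    """Hypothetical annotation: Shaped["batch channels"] returns a checkable class."""
    dims: Tuple[str, ...] = ()

    def __class_getitem__(cls, spec: str):
        return _ShapedMeta(f"Shaped[{spec!r}]", (cls,), {"dims": tuple(spec.split())})


x = FakeArray(shape=(8, 3))
print(isinstance(x, Shaped["batch channels"]))  # True
print(isinstance(x, Shaped["n n"]))             # False: "n" cannot be both 8 and 3
```

Because a runtime type checker can fall back on `isinstance` for annotations like these, no special support is needed on the checker's side for this single-argument case.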
Some additional context is needed to ensure consistency between multiple arguments (i.e. that shapes match up between arrays). For this, you can use either `jaxtyping.jaxtyped` to add this capability to a single function, or `jaxtyping.install_import_hook` to add this capability to a whole codebase. If either is too much magic for you, you can safely use neither and have just single-argument type checking.
`jaxtyping.jaxtyped` is used in conjunction with a runtime type checker. Decorate a function with this to have shapes checked for consistency across multiple arguments.

Note that `@jaxtyped` is applied above the type checker.
```python
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float32, jaxtyped

# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
# from beartype import beartype as typechecker

# Write your function. @jaxtyped must be applied above @typechecker!
@jaxtyped
@typechecker
def batch_outer_product(x: Float32[Array, "b c1"],
                        y: Float32[Array, "b c2"]
                      ) -> Float32[Array, "b c1 c2"]:
    # Broadcast (b, c1, 1) * (b, 1, c2) -> (b, c1, c2)
    return x[:, :, None] * y[:, None, :]
```
Notes for advanced users
Put precisely, all `isinstance` shape checks are scoped to the thread-local dynamic context of a `jaxtyped` call. A new dynamic context will allow different dimension sizes to be bound to the same name. After this new dynamic context is finished, the old one is restored.
For example, this means you could leave off the `@jaxtyped` decorator to enforce that this function uses the same axis sizes as the function it was called from.
Likewise, this means you can use `isinstance` checks inside a function body and have them contribute to the same collection of consistency checks performed by a typechecker against its arguments. (Or even forgo a typechecker that analyses arguments, and instead just do your own manual `isinstance` checks.)

Only `isinstance` checks that pass will contribute to the store of axis name-size pairs; those that fail will not. As such it is safe to write e.g. `assert not isinstance(x, Float32[Array, "foo"])`.
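The "only passing checks count" rule matters because a check can fail part-way through, after some axes have already matched. A minimal sketch, assuming a hypothetical `shape_matches` helper that works on a copy of the axis store and commits only on success (a toy model, not jaxtyping's code):

```python
from typing import Dict, Tuple


def shape_matches(shape: Tuple[int, ...], spec: str, memo: Dict[str, int]) -> bool:
    """Toy check: record bindings in `memo` only if the whole check passes."""
    names = spec.split()
    if len(names) != len(shape):
        return False
    trial = dict(memo)  # work on a copy so a failure leaves `memo` untouched
    for name, size in zip(names, shape):
        if trial.setdefault(name, size) != size:
            return False
    memo.update(trial)  # commit only after every axis matched
    return True


memo = {"bar": 5}
# "foo" would be bound to 3 before "bar" fails; the failure discards that binding.
assert not shape_matches((3, 4), "foo bar", memo)
assert memo == {"bar": 5}
# So a later, unrelated check is free to bind "foo" to a different size.
assert shape_matches((7,), "foo", memo)
assert memo == {"bar": 5, "foo": 7}
```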
It can be a lot of effort to add `@jaxtyped` decorators all over your codebase. (Not to mention that double-decorators everywhere are a bit ugly.)

The easier option is usually to use the import hook.
jaxtyping.install_import_hook(modules: Union[str, collections.abc.Sequence[str]], typechecker: Optional[str])
Automatically apply `@jaxtyped`, and optionally a type checker, as decorators.
```python
from jaxtyping import install_import_hook

# Plus any one of the following:

# decorate @jaxtyped and @typeguard.typechecked
with install_import_hook("foo", "typeguard.typechecked"):
    import foo          # Any module imported inside this `with` block, whose
    import foo.bar      # name begins with the specified string, will
    import foo.bar.qux  # automatically have both `@jaxtyped` and the specified
                        # typechecker applied to all of their functions.

# decorate @jaxtyped and @beartype.beartype
with install_import_hook("foo", "beartype.beartype"):
    ...

# decorate only @jaxtyped (if you want that for some reason)
with install_import_hook("foo", None):
    ...
```
If you don't like using the `with` block, the hook can be used without that:

```python
hook = install_import_hook(...)
import ...
hook.uninstall()
```
The import hook can be applied to multiple packages via `install_import_hook(["foo", "bar.baz"], ...)`.
The import hook will automatically decorate all functions, and the methods of all classes.

If the function already has any decorators on it, then both the `@jaxtyped` and the typechecker decorators will get added at the bottom of the decorator list, e.g.
```python
@some_other_decorator
@jaxtyped
@beartype.beartype
def foo(...):
    ...
```
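One way to obtain the ordering described above is to rewrite a module's syntax tree before it is compiled, appending the new decorators to each function's `decorator_list`; since the decorator listed last is applied first, the appended ones end up nearest to `def`. A sketch using the standard `ast` module (Python 3.9+ for `ast.unparse`); `AddDecorators` is a hypothetical name, and this is not claimed to be jaxtyping's code. Visiting class bodies means methods are decorated too.

```python
import ast

SOURCE = '''
@some_other_decorator
def foo(x):
    return x

class Model:
    def __call__(self, x):
        return x
'''


class AddDecorators(ast.NodeTransformer):
    """Append the given decorator expressions to every function definition."""

    def __init__(self, *decorators: str):
        self.decorators = decorators

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # also handle nested functions
        for text in self.decorators:
            # Appending puts the new decorators last, i.e. closest to `def`.
            node.decorator_list.append(ast.parse(text, mode="eval").body)
        return node

    visit_AsyncFunctionDef = visit_FunctionDef


tree = AddDecorators("jaxtyped", "beartype.beartype").visit(ast.parse(SOURCE))
out = ast.unparse(ast.fix_missing_locations(tree))
print(out)
```

The printed source starts with `@some_other_decorator`, `@jaxtyped`, `@beartype.beartype` and then `def foo(x):`, and `Model.__call__` gains the same two decorators.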
Arguments:

modules: the names of the modules in which to automatically apply `@jaxtyped` (and the typechecker, if one is given).
typechecker: the module and function of the typechecker you want to use, as a string. For example `typechecker="beartype.beartype"`. You may pass `typechecker=None` if you do not want to automatically decorate with a typechecker as well.

Returns:

A context manager that uninstalls the hook on exit, or when you call `.uninstall()`.
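Passing the typechecker as a string such as `"beartype.beartype"` means something has to import the module and look up the function. A minimal sketch of that resolution step, using a hypothetical helper `resolve_typechecker` and assuming the string always has the form `module.function`:

```python
import importlib
from typing import Callable, Optional


def resolve_typechecker(spec: Optional[str]) -> Optional[Callable]:
    """Turn a string such as "beartype.beartype" into the object it names.

    Hypothetical helper: everything before the last dot is treated as the
    module to import, and the final component as an attribute of it.
    """
    if spec is None:
        return None  # mirrors typechecker=None: no typechecker is applied
    module_name, _, attr = spec.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.function', got {spec!r}")
    return getattr(importlib.import_module(module_name), attr)
```

With the standard library alone, `resolve_typechecker("functools.lru_cache")` returns `functools.lru_cache` itself.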
Example: end-user script
```python
### entry_point.py
from jaxtyping import install_import_hook
with install_import_hook("main", "typeguard.typechecked"):
    import main

### main.py
from jaxtyping import Array, Float32

def f(x: Float32[Array, "batch channels"]):
    ...
```
Example: writing a library
```python
### __init__.py
from jaxtyping import install_import_hook
with install_import_hook("my_library_name", "beartype.beartype"):
    from .subpackage import foo          # full name is my_library_name.subpackage so
                                         # will be hook'd
    from .another_subpackage import bar  # full name is my_library_name.another_subpackage
                                         # so will be hook'd.
```
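The import hook can be installed at test-time only, as a pytest hook. From the command line the syntax is `pytest --jaxtyping-packages=foo,bar.baz,beartype.beartype`. The same option can also be set in `pyproject.toml`:

```toml
[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=foo,bar.baz,beartype.beartype"
```

or in `pytest.ini`:

```ini
[pytest]
addopts = --jaxtyping-packages=foo,bar.baz,beartype.beartype
```

This will apply the import hook to all modules whose names start with either `foo` or `bar.baz`. The typechecker used in this example is `beartype.beartype`.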
(This is the author's preferred approach to performing runtime type-checking with jaxtyping!)
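Reading the example above, the last comma-separated entry of `--jaxtyping-packages` names the typechecker and the earlier entries name module prefixes. A hypothetical parser under that assumption (a sketch of the format, not the plugin's actual code):

```python
from typing import List, Tuple


def parse_jaxtyping_packages(value: str) -> Tuple[List[str], str]:
    """Split a --jaxtyping-packages value into (module prefixes, typechecker).

    Assumption based on the documented example: entries are comma-separated,
    the last one names the typechecker, and all earlier ones are module prefixes.
    """
    *packages, typechecker = [part.strip() for part in value.split(",")]
    if not packages:
        raise ValueError("expected at least one package before the typechecker")
    return packages, typechecker


packages, typechecker = parse_jaxtyping_packages("foo,bar.baz,beartype.beartype")
print(packages, typechecker)
```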
If you are running in an IPython environment (for example a Jupyter or Colab notebook), then the jaxtyping hook can be run automatically via a custom magic:
```python
import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype  # or any other runtime type checker
```