
Runtime type checking

(See the FAQ for details on static type checking.)

Runtime type checking synergises beautifully with jax.jit! All shape checks will be performed at trace-time only, and will not impact runtime performance.

Runtime type-checking should be performed using a library like typeguard or beartype.

The types provided by jaxtyping, e.g. Float[Array, "batch channels"], are all compatible with isinstance checks, e.g. isinstance(x, Float[Array, "batch channels"]). This means that jaxtyping should be compatible with all runtime type checkers out-of-the-box.
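This works because runtime type checkers ultimately test each argument against its annotation with `isinstance`. The toy sketch below shows the idea using only the standard library. `ToyShaped`, `toy_checker` and `FakeArray` are hypothetical stand-ins, not jaxtyping's, typeguard's or beartype's real implementations, and "arrays" are any objects with a `.shape` tuple. Only the rank is checked here. Checking that sizes agree across arguments needs the extra context discussed below.

```python
import functools
import inspect
from dataclasses import dataclass


class _ToyShapedMeta(type):
    """Metaclass whose __instancecheck__ makes isinstance(x, ToyShaped["a b"]) work."""

    def __instancecheck__(cls, obj):
        shape = getattr(obj, "shape", None)
        # Toy rule: match any object whose .shape tuple has the right rank.
        return isinstance(shape, tuple) and len(shape) == len(cls.dims)


class ToyShaped:
    """Hypothetical annotation: ToyShaped["batch channels"] means "some rank-2 array"."""

    def __class_getitem__(cls, spec: str):
        return _ToyShapedMeta(f"ToyShaped[{spec!r}]", (), {"dims": tuple(spec.split())})


def toy_checker(fn):
    """Hypothetical minimal runtime checker: isinstance-checks each annotated argument."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            annotation = sig.parameters[name].annotation
            if annotation is inspect.Parameter.empty:
                continue
            if not isinstance(value, annotation):
                raise TypeError(f"{name}={value!r} is not a {annotation.__name__}")
        return fn(*args, **kwargs)

    return wrapper


@dataclass
class FakeArray:
    shape: tuple


@toy_checker
def num_elements(x: ToyShaped["batch channels"]) -> int:
    return x.shape[0] * x.shape[1]


print(num_elements(FakeArray((2, 3))))  # 6
```

Note that `toy_checker` knows nothing about `ToyShaped`: it only calls `isinstance`, so any annotation that supports `isinstance` plugs straight in.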

Some additional context is needed to ensure consistency between multiple arguments (i.e. that shapes match up between arrays). For this, you can use either jaxtyping.jaxtyped to add this capability to a single function, or jaxtyping.install_import_hook to add this capability to a whole codebase. If either is too much magic for you, you can safely use neither and settle for single-argument type checking.

jaxtyping.jaxtyped(fn)

Used in conjunction with a runtime type checker. Decorate a function with this to have shapes checked for consistency across multiple arguments.

Note that @jaxtyped is applied above the type checker.

Example

# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float32, jaxtyped

# Use your favourite typechecker: usually one of the two lines below.
# Keep only one of them (otherwise the second import shadows the first).
# from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

# Write your function. @jaxtyped must be applied above @typechecker!
@jaxtyped
@typechecker
def batch_outer_product(x: Float32[Array, "b c1"],
                        y: Float32[Array, "b c2"]
                      ) -> Float32[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]
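The ordering matters because Python applies stacked decorators bottom-up, so at call time the outermost wrapper runs first. Putting @jaxtyped on top therefore opens its shape-tracking context before the typechecker starts its isinstance checks. The sketch below demonstrates only this call order; `logging_decorator` is a hypothetical stand-in for both @jaxtyped and @typechecker.

```python
import functools

call_log = []


def logging_decorator(label):
    """Hypothetical decorator factory that records when its wrapper is entered and exited."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            call_log.append(f"enter {label}")
            try:
                return fn(*args, **kwargs)
            finally:
                call_log.append(f"exit {label}")

        return wrapper

    return decorator


@logging_decorator("outer (like @jaxtyped)")
@logging_decorator("inner (like @typechecker)")
def f(x):
    call_log.append("body")
    return x


f(1)
print(call_log)
```

The inner wrapper, and therefore every argument check it performs, runs entirely inside the outer one. With the order reversed, the typechecker's checks would run before the context existed.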

Notes for advanced users

Put precisely, all isinstance shape checks are scoped to the thread-local dynamic context of a jaxtyped call. A new dynamic context allows different dimension sizes to be bound to the same name. Once this new dynamic context finishes, the previous one is restored.

For example, this means you could leave off the @jaxtyped decorator to enforce that this function uses the same axis sizes as the function it was called from.

Likewise, this means you can use isinstance checks inside a function body and have them contribute to the same collection of consistency checks performed by a typechecker against its arguments. (Or even forgo a typechecker that analyses arguments, and instead just do your own manual isinstance checks.)

Only isinstance checks that pass will contribute to the store of axis name-size pairs; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).


It can be a lot of effort to add @jaxtyped decorators all over your codebase. (Not to mention that double-decorators everywhere are a bit ugly.)

The easier option is usually to use the import hook.

jaxtyping.install_import_hook(modules: Union[str, collections.abc.Sequence[str]], typechecker: Optional[str])

Automatically apply @jaxtyped, and optionally a type checker, as decorators.

Usage

from jaxtyping import install_import_hook
# Plus any one of the following:

# decorate @jaxtyped and @typeguard.typechecked
with install_import_hook("foo", "typeguard.typechecked"):
    import foo          # Any module imported inside this `with` block, whose
    import foo.bar      # name begins with the specified string, will
    import foo.bar.qux  # automatically have both `@jaxtyped` and the specified
                        # typechecker applied to all of their functions.

# decorate @jaxtyped and @beartype.beartype
with install_import_hook("foo", "beartype.beartype"):
    ...

# decorate only @jaxtyped (if you want that for some reason)
with install_import_hook("foo", None):
    ...

If you don't want to use a with block, you can install and uninstall the hook manually:

hook = install_import_hook("foo", "beartype.beartype")
import foo  # ...and any other modules to hook
hook.uninstall()

The import hook can be applied to multiple packages by passing a list of names:

install_import_hook(["foo", "bar.baz"], ...)

The import hook will automatically decorate all functions, and the __init__ method of dataclasses.

If the function already has any decorators on it, then both the @jaxtyped and the typechecker decorators will get added at the bottom of the decorator list, e.g.

@some_other_decorator
@jaxtyped
@beartype.beartype
def foo(...): ...
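One way to picture this is as a source transformation that appends two entries to the end of each function's decorator list (the end being the position closest to `def`). The sketch below does exactly that with the standard-library `ast` module. It is an illustrative assumption, not a description of how jaxtyping's hook is actually implemented; `AppendDecorators` is hypothetical, and the decorators are only named, never imported or run. It needs Python 3.9+ for `ast.unparse`.

```python
import ast


class AppendDecorators(ast.NodeTransformer):
    """Hypothetical transformer: add @jaxtyped and a typechecker beneath existing decorators."""

    def __init__(self, typechecker: str):
        self.typechecker = typechecker

    def visit_FunctionDef(self, node):
        self.generic_visit(node)  # also handle nested functions
        # The last entry in decorator_list is the one closest to `def`.
        node.decorator_list.append(ast.parse("jaxtyped", mode="eval").body)
        node.decorator_list.append(ast.parse(self.typechecker, mode="eval").body)
        return node


source = """
@some_other_decorator
def foo(x):
    return x
"""

tree = AppendDecorators("beartype.beartype").visit(ast.parse(source))
print(ast.unparse(ast.fix_missing_locations(tree)))
```

The printed source lists @some_other_decorator, then @jaxtyped, then @beartype.beartype, matching the decorator order described above.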

Arguments:

  • modules: the name of a module, or a sequence of module names, in which to automatically apply @jaxtyped and the typechecker.
  • typechecker: the module and function of the typechecker you want to use, as a string. For example typechecker="typeguard.typechecked", or typechecker="beartype.beartype". You may pass typechecker=None if you do not want to automatically decorate with a typechecker as well.

Returns:

A context manager that uninstalls the hook on exit, or when you call .uninstall().
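Returning a single object that works both as a context manager and through an explicit .uninstall() is a common pattern for import hooks. A minimal standard-library sketch of that pattern follows. `ToyImportHook` and `install_toy_hook` are hypothetical, and the finder they register on `sys.meta_path` declines every import, so unlike jaxtyping's hook it decorates nothing.

```python
import sys


class _PassThroughFinder:
    """Meta-path finder that declines every import, so normal importing continues."""

    def find_spec(self, fullname, path=None, target=None):
        return None


class ToyImportHook:
    """Hypothetical hook: installed on creation, removed by uninstall() or on `with` exit."""

    def __init__(self):
        self._finder = _PassThroughFinder()
        sys.meta_path.insert(0, self._finder)

    def uninstall(self):
        # Safe to call more than once.
        if self._finder in sys.meta_path:
            sys.meta_path.remove(self._finder)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.uninstall()
        return False  # never swallow exceptions raised inside the block


def install_toy_hook():
    return ToyImportHook()


# Style 1: context manager.
with install_toy_hook() as hook:
    pass  # imports made here would go through the finder

# Style 2: explicit uninstall.
hook = install_toy_hook()
hook.uninstall()
```

Installing on construction (rather than in `__enter__`) is what lets the same object serve both styles.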

Example: end-user script
### entry_point.py
from jaxtyping import install_import_hook
with install_import_hook("main", "typeguard.typechecked"):
    import main

### main.py
from jaxtyping import Array, Float32

def f(x: Float32[Array, "batch channels"]):
    ...

Example: writing a library
### __init__.py
from jaxtyping import install_import_hook
with install_import_hook("my_library_name", "beartype.beartype"):
    from .subpackage import foo  # full name is my_library_name.subpackage so
                                 # will be hook'd
    from .another_subpackage import bar  # full name is my_library_name.another_subpackage
                                         # so will be hook'd.

Pytest hook

The import hook can be installed at test-time only, as a pytest hook. From the command line the syntax is:

pytest --jaxtyping-packages=foo,bar.baz,beartype.beartype

or in pyproject.toml:

[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=foo,bar.baz,beartype.beartype"

or in pytest.ini:

[pytest]
addopts = --jaxtyping-packages=foo,bar.baz,beartype.beartype

This example will apply the import hook to all modules whose names start with either foo or bar.baz. The typechecker used in this example is beartype.beartype.
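In other words, every comma-separated entry except the last is a module-name prefix, and the last entry is the typechecker. A small sketch of that interpretation follows; `parse_jaxtyping_packages` and `is_hooked` are hypothetical helpers, and plain string-prefix matching is assumed from the phrase "whose names start with".

```python
def parse_jaxtyping_packages(value: str):
    """Hypothetical helper: split "foo,bar.baz,beartype.beartype" into prefixes and a typechecker."""
    *prefixes, typechecker = value.split(",")
    return prefixes, typechecker


def is_hooked(module_name: str, prefixes) -> bool:
    # Assumption: a module is hooked if its full name starts with any prefix.
    return any(module_name.startswith(prefix) for prefix in prefixes)


prefixes, typechecker = parse_jaxtyping_packages("foo,bar.baz,beartype.beartype")
print(prefixes, typechecker)  # ['foo', 'bar.baz'] beartype.beartype
for name in ["foo", "foo.models", "bar.baz.layers", "bar.qux"]:
    print(name, is_hooked(name, prefixes))
```

Under this reading, bar.qux is not hooked because only bar.baz (not all of bar) was listed.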

(This is the author's preferred approach to performing runtime type-checking with jaxtyping!)


IPython extension

If you are running in an IPython environment (for example a Jupyter or Colab notebook), then the jaxtyping hook can be run automatically via a custom magic:

import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype  # or any other runtime type checker
Place this at the start of your notebook: everything defined directly in the notebook after this magic is run will be hook'd.