
Runtime type checking

(See the FAQ for details on static type checking.)

Runtime type checking synergises beautifully with jax.jit! All shape checks will be performed only whilst tracing, and will not impact runtime performance.

There are two approaches: either use jaxtyping.jaxtyped to typecheck a single function, or jaxtyping.install_import_hook to typecheck a whole codebase.

In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are beartype and typeguard. (If using typeguard, then specifically the version 2.* series should be used. Later versions -- 3 and 4 -- have some known issues.)

Warning

Avoid using from __future__ import annotations, or stringified type annotations, where possible. These are largely incompatible with runtime type checking. See also this FAQ entry.
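To see one reason for this, the stdlib-only sketch below compares an ordinary annotation with a stringified one. All function names are illustrative. A runtime checker receives the string rather than the type and has to resolve it by name, which can fail when that name is not visible from the function's module globals.

```python
import typing


def eager(x: int) -> int:
    return x


def stringified(x: "int") -> "int":
    return x


# The ordinary annotation is the type object itself ...
eager_annotation = eager.__annotations__["x"]
# ... whereas the stringified annotation is just the text "int".
string_annotation = stringified.__annotations__["x"]

# Resolving the string works here because `int` is a builtin name.
resolved_annotation = typing.get_type_hints(stringified)["x"]


def make_function():
    LocalAlias = int  # only visible inside make_function

    def inner(x: "LocalAlias"):
        return x

    return inner


# Resolution looks names up in the function's globals, so a name defined
# in an enclosing function scope cannot be found.
try:
    typing.get_type_hints(make_function())
    resolution_failed = False
except NameError:
    resolution_failed = True
```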


jaxtyping.jaxtyped(fn = sentinel, *, typechecker = sentinel)

Decorate a function with this to perform runtime type-checking of its arguments and return value. Decorate a dataclass to perform type-checking of its attributes.

Example

# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float, jaxtyped

# Use your favourite typechecker: pick one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

# Type-check a function
@jaxtyped(typechecker=typechecker)
def batch_outer_product(x: Float[Array, "b c1"],
                        y: Float[Array, "b c2"]
                      ) -> Float[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]

# Type-check a dataclass
from dataclasses import dataclass

@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: int
    y: Float[Array, "b c"]

Arguments:

  • fn: The function or dataclass to decorate. In practice if you want to use dataclasses with JAX, then equinox.Module is our recommended approach:

    import equinox as eqx
    
    @jaxtyped(typechecker=typechecker)
    class MyModule(eqx.Module):
        ...
    

  • typechecker: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.

    @typechecker
    def f(x: int):
        pass
    
    f("a string is not an integer")  # this line should raise an exception
    
    Common choices are typechecker=beartype.beartype or typechecker=typeguard.typechecked. You can also set typechecker=None to skip automatic runtime type-checking while still supporting manual isinstance checks inside the function body:

    @jaxtyped(typechecker=None)
    def f(x):
        assert isinstance(x, Float[Array, "batch channel"])
    

Returns:

If fn is a function (including a staticmethod, classmethod, or property), then a wrapped function is returned.

If fn is a dataclass, then fn is returned directly, and additionally its __init__ method is wrapped and modified in-place.

Old syntax

jaxtyping previously (before v0.2.24) recommended using this double-decorator syntax:

@jaxtyped
@typechecker
def f(...): ...
This is still supported, but now raises a warning recommending the jaxtyped(typechecker=typechecker) syntax discussed above. The new syntax produces easier-to-debug error messages: under the hood, it manipulates the typechecker more carefully so as to determine where a type-check error arises.

Notes for advanced users

Dynamic contexts:

Put precisely, the axis names in e.g. Float[Array, "batch channels"] and the structure names in e.g. PyTree[int, "T"] are all scoped to the thread-local dynamic context of a jaxtyped-wrapped function. If from within that function we then call another jaxtyped-wrapped function, then a new context is pushed to the stack. The axis sizes and PyTree structures of this inner function will then not be compared against the axis sizes and PyTree structures of the outer function. After the inner function returns then this inner context is popped from the stack, and the previous context is returned to.

isinstance:

Binding of a value against a name is done with an isinstance check. For example, isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"]) will bind dim1=3 and dim2=4. In practice these isinstance checks are usually done by the runtime type checker that is supplied as the typechecker argument.

This can also be done manually: add isinstance checks inside a function body and they will contribute to the same collection of consistency checks as are performed by the typechecker on the arguments and return values. (Or you can forgo such a typechecker altogether -- i.e. typechecker=None -- and only do your own manual isinstance checks.)

Only isinstance checks that pass will contribute to the store of values; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).

Decoupling contexts from function calls:

If you would like to call a new function without creating a new dynamic context (and using the same set of axis and structure values), then simply do not add a jaxtyped decorator to your inner function, whilst continuing to perform type-checking in whatever way you prefer.

Conversely, if you would like a new dynamic context without calling a new function, then in addition to the usage discussed above, jaxtyped also supports being used as a context manager, by passing it the string "context":

with jaxtyped("context"):
    assert isinstance(x, Float[Array, "batch channel"])
This is equivalent to placing this code inside a new function wrapped in jaxtyped(typechecker=None). Usage like this is very rare; it's mostly only useful when working at the global scope.


jaxtyping.install_import_hook(modules: Union[str, Sequence[str]], typechecker: Optional[str])

Automatically apply the @jaxtyped(typechecker=typechecker) decorator to every function and dataclass over a whole codebase.

Usage

from jaxtyping import install_import_hook
# Plus any one of the following:

# decorate `@jaxtyped(typechecker=typeguard.typechecked)`
with install_import_hook("foo", "typeguard.typechecked"):
    import foo          # Any module imported inside this `with` block, whose
    import foo.bar      # name begins with the specified string, will
    import foo.bar.qux  # automatically have both `@jaxtyped` and the specified
                        # typechecker applied to all of their functions and
                        # dataclasses.

# decorate `@jaxtyped(typechecker=beartype.beartype)`
with install_import_hook("foo", "beartype.beartype"):
    ...

# decorate only `@jaxtyped` (if you want that for some reason)
with install_import_hook("foo", None):
    ...

If you don't like using the with block, the hook can be used without that:

hook = install_import_hook(...)
import ...
hook.uninstall()

The import hook can be applied to multiple packages via

install_import_hook(["foo", "bar.baz"], ...)

Arguments:

  • modules: the names of the modules in which to automatically apply @jaxtyped.
  • typechecker: the module and function of the typechecker you want to use, as a string. For example typechecker="typeguard.typechecked", or typechecker="beartype.beartype". You may pass typechecker=None if you do not want to automatically decorate with a typechecker as well.

Returns:

A context manager that uninstalls the hook on exit, or when you call .uninstall().

Example: end-user script

### entry_point.py
from jaxtyping import install_import_hook
with install_import_hook("main", "typeguard.typechecked"):
    import main

### main.py
from jaxtyping import Array, Float32

def f(x: Float32[Array, "batch channels"]):
    ...

Example: writing a library

### __init__.py
from jaxtyping import install_import_hook
with install_import_hook("my_library_name", "beartype.beartype"):
    from .subpackage import foo  # full name is my_library_name.subpackage so
                                 # will be hook'd
    from .another_subpackage import bar  # full name is my_library_name.another_subpackage
                                         # so will be hook'd.

Warning

If a function already has any decorators on it, then @jaxtyped will get added at the bottom of the decorator list, e.g.

@some_other_decorator
@jaxtyped(typechecker=beartype.beartype)
def foo(...): ...
This is to support the common case in which some_other_decorator = jax.custom_jvp etc.

If a class already has any decorators on it, then @jaxtyped will get added to the top of the decorator list, e.g.

@jaxtyped(typechecker=beartype.beartype)
@some_other_decorator
class A:
    ...
This is to support the common case in which some_other_decorator = dataclasses.dataclass.


Pytest hook

The import hook can be installed at test-time only, as a pytest hook. From the command line the syntax is:

pytest --jaxtyping-packages=foo,bar.baz,beartype.beartype
or in pyproject.toml:
[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=foo,bar.baz,beartype.beartype"
or in pytest.ini:
[pytest]
addopts = --jaxtyping-packages=foo,bar.baz,beartype.beartype
This example will apply the import hook to all modules whose names start with either foo or bar.baz. The typechecker used in this example is beartype.beartype.

IPython extension

If you are running in an IPython environment (for example a Jupyter or Colab notebook), then the jaxtyping hook can be automatically run via a custom magic:

import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype  # or any other runtime type checker
Place this at the start of your notebook -- everything that is directly defined in the notebook, after this magic is run, will be hook'd.

Other runtime type-checking libraries

Beartype and typeguard happen to be the two most popular runtime type-checking libraries (at least at time of writing), but jaxtyping should be compatible with all runtime type checkers out-of-the-box. The runtime type-checking library just needs to provide a type-checking decorator (analogous to beartype.beartype or typeguard.typechecked), and perform isinstance checks against jaxtyping's types.
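To make that contract concrete, here is a deliberately minimal, stdlib-only sketch of such a decorator. The name `simple_typechecker` is hypothetical. It only handles annotations that work directly with isinstance, it has not been validated against jaxtyping itself, and it is no substitute for beartype or typeguard.

```python
import functools
import inspect


def simple_typechecker(fn):
    """Hypothetical minimal checker: isinstance-check annotated arguments and the return value."""
    signature = inspect.signature(fn)
    plain_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            parameter = signature.parameters[name]
            if parameter.kind not in plain_kinds:
                continue  # this sketch does not check *args / **kwargs
            if parameter.annotation is inspect.Parameter.empty:
                continue  # unannotated parameters are not checked
            if not isinstance(value, parameter.annotation):
                raise TypeError(
                    f"argument {name!r} is not an instance of {parameter.annotation!r}"
                )
        result = fn(*args, **kwargs)
        expected = signature.return_annotation
        if expected is not inspect.Signature.empty and not isinstance(result, expected):
            raise TypeError(f"return value is not an instance of {expected!r}")
        return result

    return wrapper


@simple_typechecker
def double(x: int) -> int:
    return 2 * x


accepted = double(3)

try:
    double("a string is not an integer")
    rejected = False
except TypeError:
    rejected = True
```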