Runtime type checking
(See the FAQ for details on static type checking.)
Runtime type checking synergises beautifully with jax.jit! All shape checks will be performed only whilst tracing, and will not impact runtime performance.
There are two approaches: either use jaxtyping.jaxtyped to typecheck a single function, or jaxtyping.install_import_hook to typecheck a whole codebase.
In either case, the actual business of checking types is performed with the help of a runtime type-checking library. The two most popular are beartype and typeguard. (If using typeguard, then specifically the version 2.* series should be used. Later versions, 3 and 4, have some known issues.)
Warning
Avoid using from __future__ import annotations, or stringified type annotations, where possible. These are largely incompatible with runtime type checking. See also this FAQ entry.
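As a rough illustration of why, using only the standard library: under PEP 563's postponed evaluation, annotations are stored as unevaluated strings. A runtime checker therefore has to re-evaluate them itself, and that can fail if the names they mention cannot be resolved from the function's globals. The two tiny functions below are hypothetical.

```python
# Compile the same function twice: once normally, once with postponed annotations.
plain_src = "def f(x: int) -> str:\n    return str(x)\n"
postponed_src = "from __future__ import annotations\n" + plain_src

plain_ns, postponed_ns = {}, {}
exec(plain_src, plain_ns)
exec(postponed_src, postponed_ns)

plain_hints = plain_ns["f"].__annotations__          # real objects: int and str
postponed_hints = postponed_ns["f"].__annotations__  # just the strings "int" and "str"
```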
jaxtyping.jaxtyped(fn = sentinel, *, typechecker = sentinel)
Decorate a function with this to perform runtime type-checking of its arguments and return value. Decorate a dataclass to perform type-checking of its attributes.
Example
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float, jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker
# Type-check a function
@jaxtyped(typechecker=typechecker)
def batch_outer_product(x: Float[Array, "b c1"],
                        y: Float[Array, "b c2"]
                        ) -> Float[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]
# Type-check a dataclass
from dataclasses import dataclass
@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
    x: int
    y: Float[Array, "b c"]
Arguments:
- fn: The function or dataclass to decorate. In practice, if you want to use dataclasses with JAX, then equinox.Module is our recommended approach:
    import equinox as eqx

    @jaxtyped(typechecker=typechecker)
    class MyModule(eqx.Module):
        ...
- typechecker: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.
    @typechecker
    def f(x: int):
        pass

    f("a string is not an integer")  # this line should raise an exception
  Common choices are typechecker=beartype.beartype or typechecker=typeguard.typechecked. Can also be set as typechecker=None to skip automatic runtime type-checking, but still support manual isinstance checks inside the function body:
    @jaxtyped(typechecker=None)
    def f(x):
        assert isinstance(x, Float[Array, "batch channel"])
Returns:
If fn is a function (including a staticmethod, classmethod, or property), then a wrapped function is returned.
If fn is a dataclass, then fn is returned directly, and additionally its __init__ method is wrapped and modified in-place.
Old syntax
jaxtyping previously (before v0.2.24) recommended using this double-decorator syntax:
@jaxtyped
@typechecker
def f(...): ...
This is still supported, but it is now recommended to use the jaxtyped(typechecker=typechecker) syntax discussed above, which produces easier-to-debug error messages: under the hood, the new syntax manipulates the typechecker more carefully so as to determine where a type-check error arises.
Notes for advanced users
Dynamic contexts:
Put precisely, the axis names in e.g. Float[Array, "batch channels"] and the structure names in e.g. PyTree[int, "T"] are all scoped to the thread-local dynamic context of a jaxtyped-wrapped function. If from within that function we then call another jaxtyped-wrapped function, then a new context is pushed to the stack. The axis sizes and PyTree structures of this inner function will then not be compared against the axis sizes and PyTree structures of the outer function. After the inner function returns, this inner context is popped from the stack and the outer context is restored.
isinstance:
Binding of a value against a name is done with an isinstance check; for example, isinstance(jnp.zeros((3, 4)), Float[Array, "dim1 dim2"]) will bind dim1=3 and dim2=4. In practice these isinstance checks are usually done by the runtime typechecker typechecker that is supplied as an argument.
This can also be done manually: add isinstance checks inside a function body and they will contribute to the same collection of consistency checks as are performed by the typechecker on the arguments and return values. (Or you can forgo such a typechecker altogether, i.e. typechecker=None, and only do your own manual isinstance checks.)
Only isinstance checks that pass will contribute to the store of values; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).
Decoupling contexts from function calls:
If you would like to call a new function without creating a new dynamic context (and using the same set of axis and structure values), then simply do not add a jaxtyped decorator to your inner function, whilst continuing to perform type-checking in whatever way you prefer.
Conversely, if you would like a new dynamic context without calling a new function, then in addition to the usage discussed above, jaxtyped also supports being used as a context manager, by passing it the string "context":
with jaxtyped("context"):
    assert isinstance(x, Float[Array, "batch channel"])
This is equivalent to placing this code inside a new function wrapped in jaxtyped(typechecker=None). Usage like this is very rare; it's mostly only useful when working at the global scope.
jaxtyping.install_import_hook(modules: Union[str, Sequence[str]], typechecker: Optional[str])
Automatically apply the @jaxtyped(typechecker=typechecker) decorator to every function and dataclass over a whole codebase.
Usage
from jaxtyping import install_import_hook
# Plus any one of the following:
# decorate `@jaxtyped(typechecker=typeguard.typechecked)`
with install_import_hook("foo", "typeguard.typechecked"):
    import foo          # Any module imported inside this `with` block, whose
    import foo.bar      # name begins with the specified string, will
    import foo.bar.qux  # automatically have both `@jaxtyped` and the specified
                        # typechecker applied to all of their functions and
                        # dataclasses.
# decorate `@jaxtyped(typechecker=beartype.beartype)`
with install_import_hook("foo", "beartype.beartype"):
    ...
# decorate only `@jaxtyped` (if you want that for some reason)
with install_import_hook("foo", None):
    ...
If you don't like using the with block, the hook can be used without it:
hook = install_import_hook(...)
import ...
hook.uninstall()
The import hook can be applied to multiple packages via install_import_hook(["foo", "bar.baz"], ...).
Arguments:
- modules: the names of the modules in which to automatically apply @jaxtyped.
- typechecker: the module and function of the typechecker you want to use, as a string. For example typechecker="typeguard.typechecked", or typechecker="beartype.beartype". You may pass typechecker=None if you do not want to automatically decorate with a typechecker as well.
Returns:
A context manager that uninstalls the hook on exit, or when you call .uninstall().
Example: end-user script
### entry_point.py
from jaxtyping import install_import_hook
with install_import_hook("main", "typeguard.typechecked"):
    import main
### main.py
from jaxtyping import Array, Float32
def f(x: Float32[Array, "batch channels"]):
    ...
Example: writing a library
### __init__.py
from jaxtyping import install_import_hook
with install_import_hook("my_library_name", "beartype.beartype"):
    from .subpackage import foo  # full name is my_library_name.subpackage so
                                 # will be hook'd
    from .another_subpackage import bar  # full name is my_library_name.another_subpackage
                                         # so will be hook'd.
Warning
If a function already has any decorators on it, then @jaxtyped will get added at the bottom of the decorator list, e.g.
@some_other_decorator
@jaxtyped(typechecker=beartype.beartype)
def foo(...): ...
This is to support the common case in which some_other_decorator = jax.custom_jvp, etc.
If a class already has any decorators on it, then @jaxtyped will get added to the top of the decorator list, e.g.
@jaxtyped(typechecker=beartype.beartype)
@some_other_decorator
class A:
    ...
This is to support the common case in which some_other_decorator = dataclasses.dataclass.
Pytest hook
The import hook can be installed at test-time only, as a pytest hook. From the command line the syntax is:
pytest --jaxtyping-packages=foo,bar.baz,beartype.beartype
or in pyproject.toml:
[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=foo,bar.baz,beartype.beartype"
or in pytest.ini:
[pytest]
addopts = --jaxtyping-packages=foo,bar.baz,beartype.beartype
This example will apply the import hook to all modules whose names start with either foo or bar.baz. The typechecker used in this example is beartype.beartype.
IPython extension
If you are running in an IPython environment (for example a Jupyter or Colab notebook), then the jaxtyping hook can be automatically run via a custom magic:
import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype # or any other runtime type checker
Other runtime type-checking libraries
Beartype and typeguard happen to be the two most popular runtime type-checking libraries (at least at the time of writing), but jaxtyping should be compatible with all runtime type checkers out-of-the-box. The runtime type-checking library just needs to provide a type-checking decorator (analogous to beartype.beartype or typeguard.typechecked), and perform isinstance checks against jaxtyping's types.