Runtime type checking
(See the FAQ for details on static type checking.)
Runtime type checking synergises beautifully with jax.jit! All shape checks will be performed at trace-time only, and will not impact runtime performance.
Runtime type-checking should be performed using a library like typeguard or beartype.
The types provided by jaxtyping, e.g. Float[Array, "batch channels"], are all compatible with isinstance checks, e.g. isinstance(x, Float[Array, "batch channels"]). This means that jaxtyping should be compatible with all runtime type checkers out-of-the-box.
Some additional context is needed to ensure consistency between multiple arguments (i.e. that shapes match up between arrays). For this, you can use either jaxtyping.jaxtyped to add this capability to a single function, or jaxtyping.install_import_hook to add this capability to a whole codebase. If either is too much magic for you, you can safely use neither and rely on single-argument type checking alone.
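Because the annotations work with isinstance, a runtime checker essentially only has to call isinstance(value, annotation) for each argument. The toy below shows the plain-Python mechanism that makes a subscripted annotation usable in isinstance: a metaclass that defines __instancecheck__. It is a sketch under stated assumptions, not jaxtyping's implementation. Shaped, FakeArray and toy_typechecker are hypothetical names. The check looks at each argument on its own (rank and literal sizes only), which corresponds to single-argument checking with no cross-argument consistency.

```python
import functools
import inspect


class _ShapeMeta(type):
    """Metaclass whose __instancecheck__ compares an object's .shape to a spec."""

    def __instancecheck__(cls, obj):
        shape = getattr(obj, "shape", None)
        if not isinstance(shape, tuple) or len(shape) != len(cls.dims):
            return False
        # Literal integer dimensions must match exactly; named ones match any size.
        return all(not d.isdigit() or int(d) == s for d, s in zip(cls.dims, shape))


class Shaped:
    """Hypothetical annotation: Shaped["batch channels"] checks rank and literal sizes."""

    def __class_getitem__(cls, spec):
        # Build a fresh class whose metaclass performs the shape check.
        return _ShapeMeta(f"Shaped[{spec!r}]", (), {"dims": tuple(spec.split())})


class FakeArray:
    """Stand-in for an array: it only carries a .shape tuple."""

    def __init__(self, *shape):
        self.shape = shape


def toy_typechecker(fn):
    """Hypothetical checker: calls isinstance on every annotated argument."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            ann = sig.parameters[name].annotation
            if ann is not inspect.Parameter.empty and not isinstance(value, ann):
                raise TypeError(f"argument {name!r} does not match {ann.__name__}")
        return fn(*args, **kwargs)

    return wrapper


@toy_typechecker
def channels(x: Shaped["batch channels"]) -> int:
    return x.shape[1]


print(channels(FakeArray(8, 3)))  # 3
try:
    channels(FakeArray(8))  # wrong rank, so the checker rejects it
except TypeError as err:
    print("rejected:", err)
```

A real checker such as typeguard or beartype does far more than this, but for an arbitrary class annotation the core step is the same isinstance call.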
jaxtyping.jaxtyped(fn)
Used in conjunction with a runtime type checker. Decorate a function with this to have shapes checked for consistency across multiple arguments.
Note that @jaxtyped must be applied above the type checker.
Example
```python
# Import both the annotation and the `jaxtyped` decorator from `jaxtyping`
from jaxtyping import Array, Float32, jaxtyped

# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
# from beartype import beartype as typechecker

# Write your function. @jaxtyped must be applied above @typechecker!
@jaxtyped
@typechecker
def batch_outer_product(x: Float32[Array, "b c1"],
                        y: Float32[Array, "b c2"]
                        ) -> Float32[Array, "b c1 c2"]:
    return x[:, :, None] * y[:, None, :]
```
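The ordering rule follows from how Python stacks decorators. The lowest decorator is applied first, so the topmost one becomes the outermost wrapper and is the first to run at call time. The toy below records that call order. outer_scope and inner_check are hypothetical stand-ins for @jaxtyped and a typechecker, not the real decorators. With the scope on top, its context is already open when the checker runs.

```python
import functools

events = []


def outer_scope(fn):
    """Hypothetical stand-in for @jaxtyped: opens and closes a 'context'."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        events.append("open context")
        try:
            return fn(*args, **kwargs)
        finally:
            events.append("close context")
    return wrapper


def inner_check(fn):
    """Hypothetical stand-in for a typechecker: 'checks' arguments, then calls fn."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        events.append("check arguments")
        return fn(*args, **kwargs)
    return wrapper


@outer_scope   # written on top, so applied last and outermost at call time
@inner_check
def add(a, b):
    events.append("run body")
    return a + b


add(1, 2)
print(events)  # ['open context', 'check arguments', 'run body', 'close context']
```

Swapping the two decorators would make the checker run before the context is opened, which is why the order matters.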
Notes for advanced users
Put precisely, all isinstance shape checks are scoped to the thread-local dynamic context of a jaxtyped call. A new dynamic context allows different dimension sizes to be bound to the same name. Once this new dynamic context finishes, the previous one is restored.
For example, this means you could leave off the @jaxtyped decorator to enforce that this function uses the same axis sizes as the function it was called from.
Likewise, this means you can use isinstance checks inside a function body and have them contribute to the same collection of consistency checks performed by a typechecker against its arguments. (Or even forgo a typechecker that analyses arguments, and instead just do your own manual isinstance checks.)
Only isinstance checks that pass will contribute to the store of axis name-size pairs; those that fail will not. As such it is safe to write e.g. assert not isinstance(x, Float32[Array, "foo"]).
It can be a lot of effort to add @jaxtyped decorators all over your codebase. (Not to mention that double-decorators everywhere are a bit ugly.) The easier option is usually to use the import hook.
jaxtyping.install_import_hook(modules: Union[str, collections.abc.Sequence[str]], typechecker: Optional[str])
Automatically apply @jaxtyped, and optionally a type checker, as decorators.
Usage
```python
from jaxtyping import install_import_hook

# Plus any one of the following:

# decorate @jaxtyped and @typeguard.typechecked
with install_import_hook("foo", "typeguard.typechecked"):
    import foo          # Any module imported inside this `with` block, whose
    import foo.bar      # name begins with the specified string, will
    import foo.bar.qux  # automatically have both `@jaxtyped` and the specified
                        # typechecker applied to all of their functions.

# decorate @jaxtyped and @beartype.beartype
with install_import_hook("foo", "beartype.beartype"):
    ...

# decorate only @jaxtyped (if you want that for some reason)
with install_import_hook("foo", None):
    ...
```
If you don't like using the with block, the hook can be used without it:
```python
hook = install_import_hook("foo", "beartype.beartype")
import foo
hook.uninstall()
```
The import hook can be applied to multiple packages via install_import_hook(["foo", "bar.baz"], ...).
The import hook will automatically decorate all functions, and the __init__ method of dataclasses.
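Decorating only __init__ is enough to cover a dataclass's fields. The __init__ that dataclasses generates takes one parameter per init field, annotated with that field's type. The stdlib-only snippet below shows this with a hypothetical check_init decorator that isinstance-checks built-in types. It stands in for @jaxtyped plus a real typechecker and is not how the import hook applies them.

```python
import dataclasses
import functools
import inspect


def check_init(init):
    """Hypothetical checker: isinstance-checks each annotated __init__ argument."""
    sig = inspect.signature(init)

    @functools.wraps(init)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            ann = sig.parameters[name].annotation
            # Skip unannotated parameters (such as self) and non-class annotations.
            if ann is inspect.Parameter.empty or not isinstance(ann, type):
                continue
            if not isinstance(value, ann):
                raise TypeError(f"field {name!r} expects {ann.__name__}")
        return init(*args, **kwargs)

    return wrapper


@dataclasses.dataclass
class Batch:
    size: int
    name: str = "train"


# The generated __init__ has one parameter per field.
print(list(inspect.signature(Batch.__init__).parameters))  # ['self', 'size', 'name']
Batch.__init__ = check_init(Batch.__init__)

print(Batch(3))
try:
    Batch("3")  # wrong type for the `size` field
except TypeError as err:
    print("rejected:", err)
```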
If the function already has any decorators on it, then both the @jaxtyped and the typechecker decorators will get added at the bottom of the decorator list, e.g.
```python
@some_other_decorator
@jaxtyped
@beartype.beartype
def foo(...): ...
```
Arguments:

- modules: the names of the modules in which to automatically apply @jaxtyped and the typechecker.
- typechecker: the module and function of the typechecker you want to use, as a string. For example typechecker="typeguard.typechecked", or typechecker="beartype.beartype". You may pass typechecker=None if you do not want to automatically decorate with a typechecker as well.
Returns:
A context manager that uninstalls the hook on exit, or when you call .uninstall().
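Returning an object that works both as a context manager and through an explicit .uninstall() is a common pattern for Python import hooks. The sketch below shows that lifecycle with the standard library only, assuming the hook is a finder placed at the front of sys.meta_path. ToyHook and RecordingFinder are hypothetical. The plain prefix match is chosen for illustration, and the finder merely records matching imports instead of decorating anything. This is not jaxtyping's implementation.

```python
import importlib
import importlib.abc
import sys


class RecordingFinder(importlib.abc.MetaPathFinder):
    """Hypothetical finder: notes imports whose names start with a given prefix."""

    def __init__(self, prefixes):
        self.prefixes = tuple(prefixes)
        self.seen = []

    def find_spec(self, fullname, path, target=None):
        if fullname.startswith(self.prefixes):
            self.seen.append(fullname)
        return None  # defer to the normal finders, so the import proceeds as usual


class ToyHook:
    """Hypothetical hook object: usable as a context manager or via .uninstall()."""

    def __init__(self, prefixes):
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        self.finder = RecordingFinder(prefixes)
        sys.meta_path.insert(0, self.finder)  # consulted before the default finders

    def uninstall(self):
        if self.finder in sys.meta_path:
            sys.meta_path.remove(self.finder)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.uninstall()
        return False


# Force a fresh import so that the finder is actually consulted.
sys.modules.pop("colorsys", None)
with ToyHook("colorsys") as hook:
    importlib.import_module("colorsys")
print(hook.finder.seen)  # ['colorsys']
```

Writing hook = ToyHook("colorsys") and later calling hook.uninstall() has the same effect as the with block. Because uninstall checks membership first, calling it twice is harmless.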
Example: end-user script
```python
### entry_point.py
from jaxtyping import install_import_hook

with install_import_hook("main", "typeguard.typechecked"):
    import main

### main.py
from jaxtyping import Array, Float32

def f(x: Float32[Array, "batch channels"]):
    ...
```
Example: writing a library
```python
### __init__.py
from jaxtyping import install_import_hook

with install_import_hook("my_library_name", "beartype.beartype"):
    from .subpackage import foo          # full name is my_library_name.subpackage,
                                         # so will be hooked
    from .another_subpackage import bar  # full name is my_library_name.another_subpackage,
                                         # so will be hooked
```
Pytest hook
The import hook can be installed at test-time only, as a pytest hook. From the command line the syntax is:
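```bash
pytest --jaxtyping-packages=foo,bar.baz,beartype.beartype
```
or in pyproject.toml:
```toml
[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=foo,bar.baz,beartype.beartype"
```
or in pytest.ini:
```ini
[pytest]
addopts = --jaxtyping-packages=foo,bar.baz,beartype.beartype
```
This example applies the import hook to all modules whose names begin with either foo or bar.baz. The typechecker used in this example is beartype.beartype, given here as the last comma-separated entry.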
(This is the author's preferred approach to performing runtime type-checking with jaxtyping!)
IPython extension
If you are running in an IPython environment (for example a Jupyter or Colab notebook), then the jaxtyping hook can be run automatically via a custom magic:
```python
import jaxtyping
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype  # or any other runtime type checker
```