FAQ
Is jaxtyping compatible with static type checkers like mypy/pyright/pytype?
There is partial support for these. An annotation of the form dtype[array, shape] should be treated as just array by a static type checker. Unfortunately, full dtype/shape checking is beyond what static type checkers are currently capable of.
(Note that at time of writing, pytype has a bug in that dtype[array, shape] is sometimes treated as Any rather than array. mypy and pyright both work fine.)
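The standard library has an analogue of "an annotation that a static checker reduces to its underlying type": PEP 593's typing.Annotated, where a checker treats Annotated[T, metadata] as plain T while the metadata stays available at runtime. The sketch below uses it only to illustrate the idea; it is not necessarily how jaxtyping implements dtype[array, shape], and the Array class is a hypothetical placeholder.

```python
from typing import Annotated, get_type_hints


class Array:
    """Hypothetical stand-in for an array type (not jaxtyping's or JAX's Array)."""


def normalise(x: Annotated[Array, "batch channels"]) -> Annotated[Array, "batch channels"]:
    return x


# By default the metadata is stripped, leaving just the underlying type,
# which is also how a static type checker sees this annotation.
hints = get_type_hints(normalise)
print(hints["x"] is Array)  # True

# A runtime tool can still ask for the metadata (here, the shape string).
extras = get_type_hints(normalise, include_extras=True)
print(extras["x"].__metadata__)  # ('batch channels',)
```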
How does jaxtyping interact with jax.jit?
jaxtyping and jax.jit synergise beautifully.
When you call JAX operations wrapped in jax.jit, the dtype/shape checking happens at trace time, that is, when JAX traces your function prior to compiling it. The compiled code itself contains no dtype/shape checks, so it runs just as fast as before!
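As a toy model in plain Python (no JAX involved), the sketch below mimics why checking costs nothing after compilation: the check runs only when a new combination of input shapes is seen, the analogue of tracing, and later calls with the same shapes reuse a cached function that contains no checks. The decorator trace_time_checked and the helpers same_length and add are hypothetical; real jax.jit caches on abstract shapes/dtypes and compiles via XLA, which this does not attempt.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]


def trace_time_checked(check: Callable[..., None]):
    """Toy model of jit + runtime checking (hypothetical helper, not JAX's API)."""

    def decorator(fn: Callable[..., Vector]) -> Callable[..., Vector]:
        cache: Dict[Tuple[int, ...], Callable[..., Vector]] = {}
        stats = {"checks": 0}

        def wrapper(*vectors: Vector) -> Vector:
            key = tuple(len(v) for v in vectors)  # the "shape" of each input
            if key not in cache:
                check(*vectors)  # "trace time": raises on a shape mismatch
                stats["checks"] += 1
                cache[key] = fn  # stand-in for compiled code with no checks in it
            return cache[key](*vectors)

        wrapper.stats = stats  # exposed so we can count how often checking ran
        return wrapper

    return decorator


def same_length(x: Vector, y: Vector) -> None:
    if len(x) != len(y):
        raise TypeError(f"shape mismatch: {len(x)} vs {len(y)}")


@trace_time_checked(same_length)
def add(x: Vector, y: Vector) -> Vector:
    return [a + b for a, b in zip(x, y)]


add([1.0, 2.0], [3.0, 4.0])  # new shapes: the check runs
add([5.0, 6.0], [7.0, 8.0])  # same shapes: cached, no check
print(add.stats["checks"])  # 1
```

Only the first call with a given shape pays for the check; every later call with the same shape goes straight to the cached function.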
flake8 or Ruff is throwing an error.
In type annotations, strings are used for two different things. Sometimes they're strings. Sometimes they're "forward references", used to refer to a type that will be defined later.
Some tooling in the Python ecosystem assumes that only the latter is true, and will throw spurious errors if you try to use a string just as a string (like we do).
In the case of flake8 or Ruff, this can be resolved. Annotations for multi-dimensional arrays (e.g. Float32[Array, "b c"]) trigger a very unusual error (F722, syntax error in forward annotation), so you can safely disable this particular error globally. Annotations for uni-dimensional arrays (e.g. Float32[Array, "x"]) trigger an error that is actually useful in general (F821, undefined name), so rather than disabling it globally, prepend a space to the start of your shape, e.g. Float32[Array, " x"]. jaxtyping treats this in the same way, whilst flake8 now throws an F722 error that you can disable as before.
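To see why the two cases behave differently, here is a simplified sketch of the kind of check a pyflakes-style linter performs on a string annotation: parse the string as Python source, report F722 if that fails, and otherwise look up every name it contains. The helper lint_string_annotation is hypothetical and much cruder than the real flake8/Ruff logic.

```python
import ast
from typing import Collection, Optional


def lint_string_annotation(annotation: str, defined_names: Collection[str]) -> Optional[str]:
    """Return the flake8-style code a string annotation would trigger, or None.

    Hypothetical, simplified model of a pyflakes-style check (not the real implementation).
    """
    try:
        # Parse the annotation as if it were Python source code.
        tree = ast.parse(annotation)
    except SyntaxError:
        # "b c" is not valid Python; " x" fails with IndentationError (a SyntaxError subclass).
        return "F722"
    # The string must be a single expression to count as a forward reference.
    if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
        return "F722"
    # A valid expression is treated as a forward reference: every name must exist.
    for node in ast.walk(tree):
        if isinstance(node, ast.Name) and node.id not in defined_names:
            return "F821"
    return None


for shape in ["b c", "x", " x"]:
    print(repr(shape), lint_string_annotation(shape, defined_names=set()))
# 'b c' F722
# 'x' F821
# ' x' F722
```

Under this model the leading space turns the one-axis case from F821 into F722, which is why disabling F722 alone is enough.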
Dataclass annotations aren't being checked properly.
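Stringified dataclass annotations, e.g.

```python
from dataclasses import dataclass

@dataclass()
class Foo:
    x: "int"
```

are skipped rather than checked. Such stringified annotations typically arise either from forward references or from using from __future__ import annotations. (You should essentially never use the latter: it is largely incompatible with runtime type checking, and it is being superseded by the deferred evaluation of annotations from PEP 649, introduced in Python 3.14.)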
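Partially stringified dataclass annotations, e.g.

```python
from dataclasses import dataclass

@dataclass()
class Foo:
    x: tuple["int"]
```

are not supported and may raise an error, so they should not be used at all.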
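Using only the standard library, you can see what a runtime checker is handed in these two cases. The dataclass machinery stores each annotation exactly as written, so the fully stringified one is just a str, and the partially stringified one is a real tuple type with an unresolved string inside it. The classes Foo and Bar below mirror the two examples above.

```python
import dataclasses
from dataclasses import dataclass


@dataclass()
class Foo:
    x: "int"  # fully stringified


@dataclass()
class Bar:
    x: tuple["int"]  # partially stringified


foo_type = dataclasses.fields(Foo)[0].type
bar_type = dataclasses.fields(Bar)[0].type

# The annotation is the string "int", not the int class.
print(repr(foo_type))  # 'int'
# The outer tuple is real, but its argument is still the string "int".
print(bar_type.__origin__, bar_type.__args__)  # <class 'tuple'> ('int',)
```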
Does jaxtyping use PEP 646 (variadic generics)?
The intention of PEP 646 was to make it possible for static type checkers to perform shape checks of arrays. Unfortunately, this still isn't practical, so jaxtyping deliberately does not use it. (Yet?)
The real problem is that Python's static typing ecosystem is a complicated collection of edge cases. Many of them block ML/scientific computing in particular. For example:
- The static type system is intrinsically not expressive enough to describe operations like concatenation, stacking, or broadcasting.
- Axes have to be lifted to type-level variables. Meanwhile, the approach taken in libraries like jaxtyping and TorchTyping is to use value-level variables for types, because that's what the underlying JAX, PyTorch etc. libraries use! As such, making a static type checker work with these libraries would require either fundamentally rewriting them or exhaustively maintaining type stubs for them, and would still require a typing.cast any time you use anything unstubbed (e.g. any third-party library, or part of your codebase you haven't typed yet). This is a huge maintenance burden.
- Static type checkers have a variety of bugs that affect this use case. mypy doesn't support Protocols correctly. pyright doesn't support genericised subprotocols. Etc.
- Variadic generics exist. Variadic protocols do not. (It's not clear that these were contemplated.)
- The syntax for static typing is a little verbose. You have to write things like Array[Float32, Unpack[AnyShape], Literal[3], Height, Width] instead of Float32[Array, "... 3 height width"].
- The underlying type system has flaws. The numeric tower is broken; int is not a number; virtual base classes don't work; complex lies about having comparison operations, so type checkers have to lie about that lie in order to remove them again; many typing.* types don't work with isinstance; co/contra-variance are baked into containers (not specified at use-time); dict is variadic despite... not being variadic; bool is a subclass of int (!); ... etc. etc.
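Several of these quirks can be observed directly at runtime with the standard library, as in the sketch below. Note that the numbers.Number check passes only because int is registered with that ABC as a virtual subclass; static type checkers generally do not honour such registrations, which is the sense in which "int is not a number" for them.

```python
import numbers
import typing

# bool is a subclass of int.
print(issubclass(bool, int))  # True

# int is a numbers.Number only via ABC registration (a "virtual" base class):
# the runtime check passes, but Number is not in int's real inheritance chain.
print(isinstance(1, numbers.Number))  # True
print(numbers.Number in int.__mro__)  # False

# complex defines its own __lt__ slot, yet ordering two complex numbers fails.
print("__lt__" in vars(complex))  # True
try:
    _ = 1j < 2j
except TypeError as err:
    print("TypeError:", err)

# Parameterised typing.* types cannot be used with isinstance.
try:
    isinstance([1, 2], typing.List[int])
except TypeError as err:
    print("TypeError:", err)
```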