Is jaxtyping compatible with static type checkers like mypy/pyright/pytype?
There is partial support for these. An annotation of the form dtype[array, shape] should be treated as just array by a static type checker. Unfortunately, full dtype/shape checking is beyond the scope of what static type checking is currently capable of.
(Note that at the time of writing, pytype has a bug in that dtype[array, shape] is sometimes treated as Any rather than array. mypy and pyright both work fine.)
How does jaxtyping interact with jax.jit?
jaxtyping and jax.jit synergise beautifully.
When a type-checked function is wrapped in jax.jit, the dtype/shape-checking happens at trace time, i.e. when JAX traces your function prior to compiling it. The actual compiled code does not contain any dtype/shape-checking, and will therefore still be just as fast as before!
flake8 or Ruff are throwing an error.
In type annotations, strings are used for two different things. Sometimes they're strings. Sometimes they're "forward references", used to refer to a type that will be defined later.
Some tooling in the Python ecosystem assumes that only the latter is true, and will throw spurious errors if you try to use a string just as a string (like we do).
In the case of flake8 or Ruff, this can be resolved. Multi-dimensional arrays (e.g. Float32[Array, "b c"]) will throw a very unusual error (F722, syntax error in forward annotation), so you can safely just disable this particular error globally. Uni-dimensional arrays (e.g. Float32[Array, "x"]) will throw an error that's actually useful (F821, undefined name), so instead of disabling this globally, you should prepend a space to the start of your shape, e.g. Float32[Array, " x"]. jaxtyping will treat this in the same way, whilst flake8 will now throw an F722 error that you can disable as before.
Does jaxtyping use PEP 646 (variadic generics)?
The intention of PEP 646 was to make it possible for static type checkers to perform shape checks of arrays. Unfortunately, this isn't yet practical, so jaxtyping deliberately does not use it. (Yet?)
The real problem is that Python's static typing ecosystem is a complicated collection of edge cases. Many of them block ML/scientific computing in particular. For example:
The static type system is intrinsically not expressive enough to describe operations like concatenation, stacking, or broadcasting.
Axes have to be lifted to type-level variables. Meanwhile the approach taken in libraries like jaxtyping and TorchTyping is to use value-level variables for types, because that's what the underlying JAX, PyTorch etc. libraries use! As such, making a static type checker work with these libraries would require either fundamentally rewriting these libraries, or exhaustively maintaining type stubs for them, and would still require a typing.cast any time you use anything unstubbed (e.g. any third-party library, or part of your codebase you haven't typed yet). This is a huge maintenance burden.
Static type checkers have a variety of bugs that affect this use case. pyright doesn't support genericised subprotocols, etc.
Variadic generics exist. Variadic protocols do not. (It's not clear that these were contemplated.)
The syntax for static typing is a little verbose. You have to write things like Array[Float32, Unpack[AnyShape], Literal[3], Height, Width] instead of Float32[Array, "... 3 height width"].
The underlying type system has flaws. The numeric tower is broken; int is not a number; virtual base classes don't work; complex lies about having comparison operations, so type checkers have to lie about that lie in order to remove them again; typing.* don't work with isinstance; co/contra-variance are baked into containers (not specified at use-time); dict is variadic despite... not being variadic; bool is a subclass of int (!); ... etc. etc.
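A few of the quirks listed above can be observed directly with the standard library. The snippet below is only a sketch of the runtime side of these points, and the variable names are illustrative. How a given static checker treats each case is not shown here.

```python
import numbers
import typing

# bool is a subclass of int, so booleans quietly take part in arithmetic.
bool_is_int = issubclass(bool, int)
true_plus_true = True + True

# At runtime int is registered as a virtual subclass of numbers.Number.
# Static checkers generally do not model that registration.
int_is_number_at_runtime = isinstance(1, numbers.Number)

# complex exposes a rich-comparison method, yet ordering two values fails.
complex_has_lt = hasattr(complex, "__lt__")
try:
    1j < 2j
    complex_ordering_works = True
except TypeError:
    complex_ordering_works = False

# Subscripted typing generics cannot be used with isinstance.
try:
    isinstance([1, 2], typing.List[int])
    generic_isinstance_works = True
except TypeError:
    generic_isinstance_works = False
```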