FAQ¤
Optax is throwing a TypeError.¤
Probably you're writing code that looks like
optim = optax.adam(learning_rate)
optim.init(model)
# TypeError: zeros_like requires ndarray or scalar arguments, got <class 'jax._src.custom_derivatives.custom_jvp'> at position 0.
This can be fixed by doing
optim.init(eqx.filter(model, eqx.is_inexact_array))
That is, the optimiser should only be initialised with the model's trainable (inexact, i.e. floating-point) arrays; everything else, such as the activation functions of an eqx.nn.MLP, is filtered out.
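To see why the original call fails, here is a quick check you can run (a minimal sketch; the exact leaves printed will depend on your Equinox version and model):

import jax
import equinox as eqx

# An MLP's PyTree leaves include its activation functions, not just arrays.
# Optax cannot create optimiser state for these non-array leaves.
model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=2, key=jax.random.PRNGKey(0))
non_arrays = [leaf for leaf in jax.tree_util.tree_leaves(model) if not eqx.is_array(leaf)]
print(non_arrays)  # the MLP's activation functions, e.g. jax.nn.relu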
How are batch dimensions handled?¤
All layers in equinox.nn are defined to operate on single batch elements, not a whole batch. To act on a batch, use jax.vmap. This maps arbitrary JAX operations -- including any Equinox module -- over additional dimensions.
For example, if x is an array/tensor of shape (batch_size, input_size), then the following PyTorch code:
import torch
linear = torch.nn.Linear(input_size, output_size)
y = linear(x)
is equivalent to the following Equinox code:
import jax
import equinox as eqx
key = jax.random.PRNGKey(seed=0)
linear = eqx.nn.Linear(input_size, output_size, key=key)
y = jax.vmap(linear)(x)
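For concreteness, here is a self-contained version of the Equinox snippet above, with illustrative sizes (the particular numbers are not from the original example):

import jax
import equinox as eqx

batch_size, input_size, output_size = 32, 8, 4  # illustrative values
key = jax.random.PRNGKey(seed=0)
x = jax.numpy.zeros((batch_size, input_size))

# linear itself acts on a single element of shape (input_size,);
# jax.vmap maps it over the leading batch dimension.
linear = eqx.nn.Linear(input_size, output_size, key=key)
y = jax.vmap(linear)(x)
print(y.shape)  # (32, 4)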
How to share a layer across two parts of a model?¤
Use equinox.nn.Shared to tie together multiple nodes (layers, weights, ...) in a PyTree.
In particular, don't do something like this:
# Buggy code!
class Module(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(...):
        shared_linear = eqx.nn.Linear(...)
        self.linear1 = shared_linear
        self.linear2 = shared_linear

If you do this, then after a round trip through any JAX transformation (which flattens the model into a PyTree and rebuilds it), self.linear1 and self.linear2 are now two independent copies: updating one will no longer update the other.
The reason for this is that in Equinox+JAX, models are PyTrees, not DAGs (directed acyclic graphs). JAX follows a functional-programming-like style, in which the identity of an object (whether that be a layer, a weight, or whatever) doesn't matter: only its value matters. (This is known as referential transparency.)
See also the equinox.tree_check function, which can be run on a model to check whether it contains duplicate nodes.
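For reference, here is a minimal sketch of the equinox.nn.Shared pattern, loosely following its documentation (the layer sizes and the weight-tying use case are illustrative, not taken from the text above):

import jax
import equinox as eqx

class WeightTiedModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self, key):
        ekey, lkey = jax.random.split(key)
        embedding = eqx.nn.Embedding(num_embeddings=3, embedding_size=4, key=ekey)
        linear = eqx.nn.Linear(4, 3, key=lkey)
        # Tie the linear layer's weight to the embedding's weight.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, token):
        # Unpack the shared PyTree; the two weights are now the same array.
        embedding, linear = self.shared()
        return linear(embedding(token))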
My model is slow...¤
...to train.¤
Make sure you have JIT covering all JAX operations.
Most autodifferentiable programs will have a "numerical bit" (e.g. a training step for your model) and a "normal programming bit" (e.g. saving models to disk). JAX makes this difference explicit. All the numerical work should go inside a single big JIT region, within which all numerical operations are compiled.
See the RNN example for an example of good practice: the whole make_step function is JIT-compiled in one go.
Common mistakes are to put jax.jit/eqx.filter_jit on just your loss function, and to leave out either (a) computing gradients or (b) applying updates with eqx.apply_updates.
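As a rough illustration of the recommended pattern, here is a minimal training-step sketch (it assumes model, loss_fn and a batch (x, y) are already defined; the optimiser and learning rate are placeholders -- see the RNN example for the real thing):

import equinox as eqx
import optax

optim = optax.adam(1e-3)  # illustrative optimiser and learning rate
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_jit  # JIT the *whole* step: loss, gradients and updates together.
def make_step(model, opt_state, x, y):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss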
...to compile.¤
95% of the time, it's because you've done something like this:
@jax.jit
def f(x):
    for i in range(100):
        x = my_complicated_function(x)
    return x
That is, a large Python for loop inside JIT. (jax.jit replaces the x argument with a tracer object that records everything that happens to it -- and this effectively unrolls the loop.) As a result you'll get 100 independent copies of my_complicated_function, which all get compiled separately.
In this case, a jax.lax.scan is probably what you want. Likewise it's usually also preferable to rewrite even simple stuff like

x2 = f(x1)
x3 = f(x2)

as a short jax.lax.scan, so that f is only traced (and compiled) once.
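For example, a minimal sketch of that rewrite (assuming f and x1 are defined as above):

import jax

def body(carry, _):
    # f is traced exactly once here, rather than once per unrolled call.
    return f(carry), None

x3, _ = jax.lax.scan(body, x1, xs=None, length=2)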
TypeError: not a valid JAX type.¤
You might be getting an error like
TypeError: Argument '<function ...>' of type <class 'function'> is not a valid JAX type.
For example:

import jax
import equinox as eqx

def loss_fn(model, x, y):
    return ((model(x) - y) ** 2).mean()

model = eqx.nn.MLP(2, 2, 2, 2, key=jax.random.PRNGKey(0))
x = jax.numpy.arange(2)
y = x * x

try:
    jax.jit(loss_fn)(model, x, y)  # error
except TypeError as e:
    print(e)

eqx.filter_jit(loss_fn)(model, x, y)  # ok
This error happens because a model, when treated as a PyTree, may have leaves that are not JAX types (such as functions). It only makes sense to trace arrays. Filtering is used to handle this.
Instead of jax.jit, use equinox.filter_jit. Likewise for other transformations.
How to mark arrays as non-trainable? (Like PyTorch's buffers?)¤
This can be done by using jax.lax.stop_gradient:
import jax
import equinox as eqx
from jaxtyping import Array

class Model(eqx.Module):
    buffer: Array
    param: Array

    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)
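As a quick check of the behaviour (the particular values here are illustrative), gradients with respect to the buffer come out as zero:

import jax.numpy as jnp

model = Model(buffer=jnp.array(2.0), param=jnp.array(3.0))
grads = eqx.filter_grad(lambda m, x: m(x))(model, jnp.array(1.0))
print(grads.param)   # 1.0 -- trained as normal
print(grads.buffer)  # 0.0 -- the buffer receives no gradient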
I think my function is being recompiled each time it is run.¤
You can check each time your function is compiled by adding a print statement:
@eqx.filter_jit
def your_function(x, y, z):
    print("Compiling!")
    ...  # rest of your code here
A function will be recompiled every time the shape or dtype of its arrays changes, or if any of its static (non-array) inputs change (as measured by __eq__).
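As a small illustration of this rule (a sketch; your_function above behaves the same way):

import jax.numpy as jnp
import equinox as eqx

@eqx.filter_jit
def f(x):
    print("Compiling!")
    return x + 1

f(jnp.zeros(2))  # prints "Compiling!"
f(jnp.ones(2))   # same shape and dtype: no recompilation, nothing printed
f(jnp.zeros(3))  # new shape: prints "Compiling!" again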
If you want to check which argument is causing an undesired recompilation, then this can be done by checking each argument in turn:
@eqx.filter_jit
def check_arg(arg, name):
    print(f"Argument {name} is triggering a compile.")

for step, (x, y, z) in enumerate(...):  # e.g. a training loop
    print(f"Step is {step}")
    check_arg(x, "x")
    check_arg(y, "y")
    check_arg(z, "z")
    your_function(x, y, z)
This will produce output that looks something like:

Step is 0
Argument x is triggering a compile.
Argument y is triggering a compile.
Argument z is triggering a compile.
Step is 1
Argument y is triggering a compile.
Step is 2
Argument y is triggering a compile.
...
In the output above it is y that keeps triggering a recompilation of check_arg -- so it will be the one triggering a recompilation of your_function as well!
How does Equinox compare to...?¤
...PyTorch?¤
JAX+Equinox is usually faster than PyTorch (JAX has a stronger JIT compiler), and more featureful (e.g. supporting jit-of-vmap, forward-mode autodifferentiation, and autoparallelism).
For those doing scientific computing or scientific ML, JAX+Equinox also has a much stronger ecosystem. For example, PyTorch no longer has a maintained library for solving differential equations (torchdiffeq is unmaintained). Meanwhile, JAX has Diffrax.
Both JAX+Equinox and PyTorch are roughly equally easy to use. PyTorch tends to be easier for new users (e.g. it's closer to being "Python as normal", and there's less functional programming), whilst JAX+Equinox generally supports advanced use-cases more cleanly (e.g. PyTorch has multiple JIT compilers, each with their own quirks -- torch.{fx, jit.script, jit.trace, compile, _dynamo, ...} -- whilst JAX+Equinox just has the one).
PyTorch is older, and enjoys broader adoption -- it's generally easier to find developers for PyTorch, or off-the-shelf model implementations using it.
...Keras?¤
These are two very different libraries, with very different target audiences. Keras is great for plug-and-play building of models -- it's often compared to using Lego. This makes it a convenient framework for standing up neural networks quickly. Equinox is much lower level: it tries to support more general use-cases (e.g. its downstream scientific ecosystem), but usually requires greater proficiency with numerical computing / software development / machine learning.
...Flax?¤
- Flax introduces multiple new abstractions (flax.linen.Module, flax.linen.Variable, Module.setup vs flax.linen.compact, flax.struct.dataclass, etc.). Equinox tries to avoid adding new abstractions to core JAX; everything is always just a PyTree.
- Flax is a DSL: it is generally incompatible with non-Flax code, and requires using the wrapped flax.linen.{vmap, scan, ...} rather than the native jax.{vmap, ...}. In contrast, Equinox allows you to use native JAX operations and aims to be compatible with arbitrary JAX code.
- Bound methods of eqx.Module are just PyTrees. In Flax this isn't the case -- passing around bound methods will either result in errors or recompilations, depending on what you do. Likewise, eqx.Module handles inheritance correctly, including propagating metadata like docstrings; the equivalent flax.struct.dataclass silently misbehaves. Overall Equinox seems to have fewer footguns.
- Equinox offers several advanced features (like runtime errors or PyTree manipulation) not found in other libraries.
See also the Equinox paper.
...Julia?¤
The Julia ecosystem has historically been buggy. In particular, at the time of writing Julia does not yet have a single robust autodifferentiation system: there are multiple competing implementations -- Diffractor.jl and ForwardDiff.jl for forward mode, and Tracker.jl, Zygote.jl, Enzyme.jl, and ReverseDiff.jl for reverse mode -- and higher-order autodifferentiation is not yet robustly supported. In contrast, JAX+Equinox use a single strong autodifferentiation system.
However, note that JAX+Equinox don't try to offer a completely general programming model: they are optimised for arrays and linear algebra. (Essentially, the sorts of things you use NumPy for.) They're not designed for e.g. a branch-and-bound combinatorial optimisation algorithm, and for these purposes Julia will be superior.
Julia is often a small amount faster on microbenchmarks on CPUs. JAX+Equinox supports running on TPUs, whilst Julia generally does not.
You're obviously biased! Are the above comparisons fair?¤
Seriously, we think they're fair! Nonetheless, all of the above approaches have their adherents, so each of them is clearly doing something right. If you're already happily using one of them for your current project, then keep using it. (Don't rewrite things for no reason.) But conversely, we'd invite you to try Equinox for your next project. :)
For what it's worth, if you have the time to learn (e.g. you're a grad student), then we'd strongly recommend trying all of the above. All of these libraries have made substantial innovations, and have all substantially moved the numerical computing space forward. Equinox deliberately takes inspiration from them!