All of Equinox¤
Equinox is a small and easy-to-understand library. As the title suggests, this page tells you essentially everything you need to know to use Equinox.
1. Models as PyTrees¤
What's a PyTree?
PyTrees are what JAX calls nested collections of tuples, lists, and dicts. (And any custom-registered PyTree nodes.) The "leaves" of the tree can be anything at all: JAX/NumPy arrays, floats, functions, etc. Most JAX operations will accept (a) arbitrary PyTrees; (b) PyTrees with just JAX/NumPy arrays as the leaves; or (c) PyTrees without any JAX/NumPy arrays as the leaves.
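For instance, here's a minimal sketch (not part of the original example; the names are illustrative) of a PyTree and its leaves:

```python
import jax

# Any nesting of tuples/lists/dicts is a PyTree; the leaves are everything else:
# here an array, a function, and an int.
tree = {"weight": jax.numpy.ones((2, 2)), "activation": jax.nn.relu, "count": 3}
# `tree_leaves` flattens away the structure and returns just the leaves
# (dict keys are traversed in sorted order).
print(jax.tree_util.tree_leaves(tree))
```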
As we saw on the Getting Started page, Equinox offers the ability to represent models as PyTrees. This is one of Equinox's main features.
Once we've done so, we'll be able to JIT/grad/etc. with respect to the model. For example, using a few built-in layers by way of demonstration, here's a small neural network:
```python
import equinox as eqx
import jax


class NeuralNetwork(eqx.Module):
    layers: list
    extra_bias: jax.Array

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        # These contain trainable parameters.
        self.layers = [eqx.nn.Linear(2, 8, key=key1),
                       eqx.nn.Linear(8, 8, key=key2),
                       eqx.nn.Linear(8, 2, key=key3)]
        # This is also a trainable parameter.
        self.extra_bias = jax.numpy.ones(2)

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x) + self.extra_bias


@jax.jit  # compile this function to make it run fast.
@jax.grad  # differentiate all floating-point arrays in `model`.
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)  # vectorise the model over a batch of data
    return jax.numpy.mean((y - pred_y) ** 2)  # L2 loss


x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
# Example data
x = jax.random.normal(x_key, (100, 2))
y = jax.random.normal(y_key, (100, 2))
model = NeuralNetwork(model_key)
# Compute gradients
grads = loss(model, x, y)
# Perform gradient descent
learning_rate = 0.1
new_model = jax.tree_util.tree_map(lambda m, g: m - learning_rate * g, model, grads)
```
In this example, `model = NeuralNetwork(...)` is the overall PyTree. Nested within it are `model.layers` and `model.extra_bias`. The former is itself a PyTree, containing three `eqx.nn.Linear` layers at `model.layers[0]`, `model.layers[1]`, and `model.layers[2]`. Each of these is also a PyTree, containing its weight and bias, e.g. `model.layers[0].weight`.
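If you'd like to see this structure for yourself, one quick check (a sketch, not part of the original example) is to count the model's leaves:

```python
# Three weights + three biases + `extra_bias` = seven array leaves in total.
print(len(jax.tree_util.tree_leaves(model)))  # 7
eqx.tree_pprint(model)  # pretty-prints the nested Module structure
```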
2. Filtering¤
In the previous example, all of the leaves were JAX arrays. This made things simple, because `jax.jit`- and `jax.grad`-decorated functions require that all of their inputs are PyTrees of arrays.

Equinox goes further, and supports using arbitrary Python objects for its leaves. For example, we might like to make our activation function part of the PyTree (rather than just hardcoding it as above). The activation function is just some arbitrary Python function, and this isn't an array. Another common example is having a boolean flag in your model, which specifies whether to enable some extra piece of behaviour.
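For example (a hypothetical sketch; `FlaggedLinear` is not an Equinox layer), such a flag is just an ordinary field, and becomes a non-array leaf of the PyTree:

```python
import equinox as eqx
import jax

class FlaggedLinear(eqx.Module):
    linear: eqx.nn.Linear
    apply_relu: bool  # a plain Python bool: a non-array leaf of the PyTree

    def __init__(self, key, apply_relu):
        self.linear = eqx.nn.Linear(2, 2, key=key)
        self.apply_relu = apply_relu

    def __call__(self, x):
        out = self.linear(x)
        return jax.nn.relu(out) if self.apply_relu else out
```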
To support this, Equinox offers filtering, as follows.
Create a model
Start off by creating a model just like normal, now with some arbitrary Python objects as part of its PyTree structure. In this case, we have `jax.nn.relu`, which is a Python function.
```python
import equinox as eqx
import functools as ft
import jax


class NeuralNetwork2(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2 = jax.random.split(key)
        self.layers = [eqx.nn.Linear(2, 8, key=key1),
                       jax.nn.relu,
                       eqx.nn.Linear(8, 2, key=key2)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
x, y = jax.random.normal(x_key, (100, 2)), jax.random.normal(y_key, (100, 2))
model = NeuralNetwork2(model_key)
```
Option 1: use `eqx.{partition,combine}`
```python
@ft.partial(jax.jit, static_argnums=1)  # `static` must be a PyTree of non-arrays.
@jax.grad  # differentiates with respect to `params`, as it is the first argument
def loss(params, static, x, y):
    model = eqx.combine(params, static)
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)


params, static = eqx.partition(model, eqx.is_array)
loss(params, static, x, y)
```
Here, we split our model PyTree into two pieces. `params` and `static` are both instances of `NeuralNetwork2`. `params` keeps just the leaves that are arrays; `static` keeps everything else. Then `combine` merges the two PyTrees back together after crossing the `jax.jit` and `jax.grad` API boundaries.

The choice of `eqx.is_array` is a filter function: a boolean function specifying whether each leaf should go into `params` or into `static`. In this case, very simply, `eqx.is_array(x)` returns `True` for JAX and NumPy arrays, and `False` for everything else.
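Concretely (a small illustrative check, not from the original text):

```python
print(eqx.is_array(jax.numpy.ones(2)))  # True  -> this leaf goes into `params`
print(eqx.is_array(jax.nn.relu))        # False -> this leaf goes into `static`
```

You could equally pass your own filter function to `eqx.partition` if you need a different split.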
Option 2: use filtered transformations
```python
@eqx.filter_jit
@eqx.filter_grad
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)


loss(model, x, y)
```
As a convenience, `eqx.filter_jit` and `eqx.filter_grad` wrap filtering and transformation together. It turns out to be really common to need filtering only around JAX transformations.

If your models only use JAX arrays, then `eqx.filter_{jit,grad,...}` do exactly the same as `jax.{jit,grad,...}`. So if you just want to keep things simple, it is safe to always use `eqx.filter_{jit,grad,...}`.

Both approaches are equally valid. Some people prefer the shorter syntax of the filtered transformations; others prefer to see the `jax.{jit,grad,...}` operations explicitly.
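Whichever option you pick, the gradient update afterwards looks much like before. One detail worth knowing: `eqx.filter_grad` returns `None` gradients for the non-array leaves, and `eqx.apply_updates` is a small helper that skips those. A minimal sketch, continuing from the filtered version above:

```python
grads = loss(model, x, y)  # `None` at non-array leaves such as `jax.nn.relu`
learning_rate = 0.1
# `tree_map` treats `None` as an empty subtree, so this only touches the arrays.
updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
model = eqx.apply_updates(model, updates)  # `None` updates leave leaves unchanged
```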
3. PyTree manipulation routines¤
Equinox clearly places a heavy focus on PyTrees! As such, it's quite common to need to perform operations on PyTrees. Whilst many common operations are already provided by JAX (for example, `jax.tree_util.tree_map` will apply an operation to every leaf of a PyTree), Equinox additionally offers some extra routines. For example, `eqx.tree_at` replaces a particular leaf or leaves of a PyTree, out-of-place, so the original is left untouched.
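For instance, a minimal sketch using the model from earlier (`model.layers[0].weight` has shape `(8, 2)`, since `Linear(2, 8)` stores its weight as `(out_features, in_features)`):

```python
new_weight = jax.numpy.zeros((8, 2))
# Out-of-place update: returns a copy of `model` with just that one leaf replaced.
new_model = eqx.tree_at(lambda m: m.layers[0].weight, model, replace=new_weight)
```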
4. Advanced goodies¤
Finally, Equinox offers a number of more advanced goodies, like serialisation, debugging tools, and runtime errors. We won't discuss them here, but check out the API reference on the left.
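For a taste, here's what serialisation looks like (a minimal sketch; `"model.eqx"` is just an example filename):

```python
eqx.tree_serialise_leaves("model.eqx", model)  # save all array leaves to disk
# To load: build a skeleton with the right structure, then fill in its leaves.
model2 = NeuralNetwork2(jax.random.PRNGKey(0))
model2 = eqx.tree_deserialise_leaves("model.eqx", model2)
```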
5. Summary¤
Equinox integrates smoothly with JAX
Equinox introduces a powerful yet straightforward way to build neural networks, without introducing lots of new notions or tying you into a framework. Indeed, Equinox is a library, not a framework: anything you write with Equinox is fully compatible with anything else in the JAX ecosystem.
Equinox is all just regular JAX: PyTrees and transformations. Together, these two pieces allow us to specify complex models in JAX-friendly ways.
API reference

- For building models: `equinox.Module`.
- Prebuilt neural network layers: `equinox.nn.Linear`, `equinox.nn.Conv2d`, etc.
- Filtered transformations: `equinox.filter_jit`, etc.
- Tools for PyTree manipulation: `equinox.partition`, etc.
- Advanced goodies: serialisation, debugging tools, runtime errors, etc.

See the API reference on the left.
Next steps
And that's it! That's pretty much everything you need to know about Equinox, and should be enough to get started with using the library. Also see the Train RNN example for a fully worked example.