
All of Equinox

Equinox is a very small and easy-to-understand library. (Because it uses JAX abstractions like PyTrees, it doesn't need any real complexity of its own.)

So, as the title suggests, this page tells you essentially everything you need to know to use Equinox.

Parameterised functions as PyTrees

As we saw on the Getting Started page, Equinox represents parameterised functions as PyTrees.

Example

A neural network is a function parameterised by its weights, biases, etc.

But you can use Equinox to represent any kind of parameterised function! For example, Diffrax uses Equinox to represent numerical differential equation solvers.

And now you can JIT/grad/etc. with respect to your model. For example, using a few built-in layers by way of demonstration:

import equinox as eqx
import jax

class MyModule(eqx.Module):
    layers: list
    bias: jax.numpy.ndarray

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = [eqx.nn.Linear(2, 8, key=key1),
                       eqx.nn.Linear(8, 8, key=key2),
                       eqx.nn.Linear(8, 2, key=key3)]
        self.bias = jax.numpy.ones(2)

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x) + self.bias

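# jax.grad differentiates with respect to the first argument (the model),
# so calling `loss` returns gradients with the same PyTree structure as the model.
@jax.jit
@jax.grad
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)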

x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(x_key, (100, 2))
y = jax.random.normal(y_key, (100, 2))
model = MyModule(model_key)
grads = loss(model, x, y)
learning_rate = 0.1
# One step of plain gradient descent, applied leaf-by-leaf across the model PyTree.
model = jax.tree_util.tree_map(lambda m, g: m - learning_rate * g, model, grads)

Filtering

In the previous example, the model was built entirely out of Modules and JAX arrays.

But Equinox supports using arbitrary Python objects too! If you choose to include those in your models, then Equinox offers the tools to handle them appropriately around jax.jit and jax.grad, which themselves only really work with JAX arrays.

Example

The activation function in equinox.nn.MLP isn't a JAX array; it's an arbitrary Python function.

Example

You might have a boolean flag in your model-as-a-PyTree, specifying whether to enable some extra piece of behaviour. You might want to treat that as a static_argnum to jax.jit.

If you want to do this, then Equinox offers filtering, as follows.

Create a model

Start off by creating a model as normal, but with some arbitrary Python objects as part of its parameterisation. In this case, we have jax.nn.relu, which is a Python function.

import equinox as eqx
import functools as ft
import jax

class AnotherModule(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2 = jax.random.split(key)
        self.layers = [eqx.nn.Linear(2, 8, key=key1),
                       jax.nn.relu,
                       eqx.nn.Linear(8, 2, key=key2)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

x_key, y_key, model_key = jax.random.split(jax.random.PRNGKey(0), 3)
x, y = jax.random.normal(x_key, (100, 2)), jax.random.normal(y_key, (100, 2))
model = AnotherModule(model_key)

Option 1: manually filter out anything that isn't JIT/grad-able.

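# `static` holds only the non-array leaves (here jax.nn.relu), so mark it as a static
# argument to jax.jit; jax.grad differentiates with respect to `params` only.
@ft.partial(jax.jit, static_argnums=1)
@jax.grad
def loss(params, static, x, y):
    model = eqx.combine(params, static)  # reassemble the full model inside the trace
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)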

params, static = eqx.partition(model, eqx.is_array)
loss(params, static, x, y)

Here, params and static are both instances of AnotherModule: params keeps just the leaves that are JAX arrays, with None in place of everything else; static keeps everything else, with None in place of the arrays. Then eqx.combine merges the two PyTrees back together after crossing the jax.jit and jax.grad API boundaries.

Here eqx.is_array is the filter function: a boolean function specifying whether each leaf should go into params or into static. In this case it is very simple: eqx.is_array(x) returns True for JAX and NumPy arrays.

Option 2: use filtered transformations, which automate the above process for you.

@eqx.filter_jit
@eqx.filter_grad
def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

loss(model, x, y)

As a convenience, eqx.filter_jit and eqx.filter_grad wrap filtering and transformation together. It turns out to be really common to need filtering only around JAX transformations, so bundling the two saves a lot of boilerplate.

If your models only use JAX arrays, then eqx.filter_{jit,grad,...} will do exactly the same as jax.{jit,grad,...}. So if you want to keep things simple, it is safe to always use eqx.filter_{jit,grad,...}.

Integrates smoothly with JAX

Equinox introduces a powerful yet straightforward way to build neural networks, without bringing in lots of new notions or tying you into a framework.

Equinox is all just regular JAX: PyTrees and transformations. Together, these two pieces allow us to specify complex models in JAX-friendly ways.

Next steps

And that's it! That's pretty much everything you need to know about Equinox. Everything you've seen so far should be enough to get started with the library. Also see the Train RNN example for a fully worked example.

Summary

Equinox includes four main things:

  • For building models: equinox.Module.
  • Prebuilt neural network layers: equinox.nn.Linear, equinox.nn.Conv2d, etc.
  • Filtering, and filtered transformations: equinox.filter, equinox.filter_jit etc.
  • Some utilities to help manipulate PyTrees: equinox.tree_at etc.

See also the API reference.

FAQ

One common question: a lot of other libraries introduce custom library.jit etc. operations, specifically to work with library.Module. What makes the filtered transformations of Equinox different?

The answer is that filtered transformations are tools that apply to any PyTree, and models just happen to be PyTrees. The filtered transformations and eqx.Module are not coupled together; they are independent tools.