
Stateful operations (e.g. BatchNorm)

Some layers, such as equinox.nn.BatchNorm, are sometimes called "stateful": this refers to the fact that they take an additional input (in the case of BatchNorm, the running statistics) and return an additional output (the updated running statistics).

This just means that we need to plumb an extra input and output through our models. This example demonstrates both equinox.nn.BatchNorm and equinox.nn.SpectralNorm.
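To make the "plumbing" concrete, here is a minimal sketch in plain JAX of what a stateful layer looks like when written as a pure function. `running_mean_layer` is a hypothetical layer invented purely for illustration; it is not part of Equinox, and it only mimics the input/output shape of a stateful call.

```python
import jax.numpy as jnp


def running_mean_layer(x, state, momentum=0.9):
    # Hypothetical "stateful" layer: it takes the input `x` plus a `state`
    # (here, a running mean) and returns the output plus the updated state.
    # Nothing is mutated in place.
    new_state = momentum * state + (1 - momentum) * x
    return x - new_state, new_state


state = jnp.zeros(3)
for x in [jnp.ones(3), jnp.ones(3)]:
    # The caller is responsible for feeding the returned state back in.
    y, state = running_mean_layer(x, state)
```

After two calls the running mean is 0.19 and the output is 0.81 in every entry: the state only "persists" because we explicitly pass it back in.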

See also the stateful API reference.

This example is available as a Jupyter notebook here.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # https://github.com/deepmind/optax

# This model is just a weird mish-mash of stateful and non-stateful layers for
# demonstration purposes; it isn't doing anything clever.
class Model(eqx.Module):
    norm1: eqx.nn.BatchNorm
    spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]
    norm2: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jr.split(key, 4)
        self.norm1 = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        self.spectral_linear = eqx.nn.SpectralNorm(
            layer=eqx.nn.Linear(in_features=3, out_features=32, key=key1),
            weight_name="weight",
            key=key2,
        )
        self.norm2 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key3)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key4)

    def __call__(self, x, state):
        x, state = self.norm1(x, state)
        x, state = self.spectral_linear(x, state)
        x = jax.nn.relu(x)
        x, state = self.norm2(x, state)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state

As the above shows, we define our model just like normal. As advertised, the only difference is that we thread the additional state object in and out of every call to a stateful layer, and each such call returns an updated state object.

There's really nothing special here about stateful layers. Equinox isn't special-casing them in any way. We thread state in and out, just like we thread x in and out. In fact, calling it "state" is really just a matter of how it's advertised!


Alright, now let's see how we might train this model. This is also much like normal.

Note the use of in_axes and out_axes: our data is batched, but our model state isn't, just as our model parameters aren't.

Note how the axis_name argument to jax.vmap matches the axis_name argument that the BatchNorm layers were initialised with. This tells BatchNorm which vmap'd axis it should compute statistics over. (This is a detail specific to BatchNorm, and is unrelated to stateful operations in general.)

def compute_loss(model, state, xs, ys):
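    # Map over axis 0 of `xs`, but not over `state` (None). The returned state
    # is likewise unbatched, so its out_axes entry is also None.
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )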
    pred_ys, state = batch_model(xs, state)
    loss = jnp.mean((pred_ys - ys) ** 2)
    return loss, state


@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    grads, state = eqx.filter_grad(compute_loss, has_aux=True)(model, state, xs, ys)
    # `optim` is the optax optimiser defined below, before make_step is first called.
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state
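The interplay of axis_name, in_axes and out_axes can be seen in isolation with plain JAX. The sketch below uses a hypothetical `centre` helper (not Equinox's actual BatchNorm code) that, like BatchNorm in spirit, computes a statistic over the whole vmapped batch via `jax.lax.pmean` and folds it into an unbatched running value.

```python
import jax
import jax.numpy as jnp


def centre(x, running_mean):
    # Inside jax.vmap, `x` is a single example, but pmean averages over every
    # example in the vmapped axis named "batch".
    batch_mean = jax.lax.pmean(x, axis_name="batch")
    # The batch mean is shared by all examples, so the updated running mean is
    # unbatched and can be returned with out_axes=None.
    new_running_mean = 0.9 * running_mean + 0.1 * batch_mean
    return x - batch_mean, new_running_mean


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
running_mean = jnp.array(0.0)
batched = jax.vmap(
    centre, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)
ys, running_mean = batched(xs, running_mean)
```

Here `ys` is `[-1.5, -0.5, 0.5, 1.5]` (each example minus the batch mean 2.5) and the running mean becomes 0.25, a single scalar rather than one per example.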

And now, let's see how we initialise this model, and initialise its state.

dataset_size = 10
learning_rate = 3e-4
steps = 5
seed = 5678

key = jr.PRNGKey(seed)
mkey, xkey, xkey2 = jr.split(key, 3)

model, state = eqx.nn.make_with_state(Model)(mkey)

xs = jr.normal(xkey, (dataset_size, 3))
ys = jnp.sin(xs) + 1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)

What is eqx.nn.make_with_state doing?

Here we come to the only interesting bit about using stateful layers!

When we initialise the model (e.g. if we were to call Model(mkey) directly), the resulting PyTree contains both (a) the initial parameters and (b) the initial state. make_with_state simply calls the constructor and then separates these two things: the returned model is a PyTree holding all the initial parameters (just like any other model), and state is a PyTree holding the initial state.


Finally, let's use our trained model to perform inference.

Remember to set the inference flag! Some layers have different behaviour between training and inference, and BatchNorm is one of these. (This is a detail specific to layers like BatchNorm and equinox.nn.Dropout, and is unrelated to stateful operations in general.)

We also fix the final state in the model, using equinox.Partial. The resulting inference_model is a PyTree (specifically, an equinox.Partial) containing both model and state.

inference_model = eqx.nn.inference_mode(model)
inference_model = eqx.Partial(inference_model, state=state)


@eqx.filter_jit
def evaluate(model, xs):
    # discard state
    out, _ = jax.vmap(model)(xs)
    return out


test_dataset_size = 5
test_xs = jr.normal(xkey2, (test_dataset_size, 3))
pred_test_ys = evaluate(inference_model, test_xs)

Here, we don't need the updated state object that is produced, so we just discard it.