Stateful operations (e.g. BatchNorm)
Some layers, such as `equinox.nn.BatchNorm`, are sometimes called "stateful": this refers to the fact that they take an additional input (in the case of `BatchNorm`, the running statistics) and return an additional output (the updated running statistics).
See also the stateful API reference.
This example is available as a Jupyter notebook here.
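In other words, where a stateless layer is called as `y = layer(x)`, a stateful layer is called as `y, state = layer(x, state)`. Here's a minimal sketch of that call signature for a single standalone `BatchNorm` layer, using `eqx.nn.make_with_state` (discussed properly below) to obtain the initial state:

```python
import equinox as eqx
import jax
import jax.random as jr

bn, bn_state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(
    input_size=3, axis_name="batch"
)
x = jr.normal(jr.PRNGKey(0), (8, 3))  # a batch of 8 three-dimensional inputs
# BatchNorm computes its statistics across the vmap'd "batch" axis.
y, bn_state = jax.vmap(
    bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(x, bn_state)
```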
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # https://github.com/deepmind/optax
```
```python
# This model is just a weird mish-mash of stateful and non-stateful layers for
# demonstration purposes, it isn't doing anything clever.
class Model(eqx.Module):
    norm1: eqx.nn.BatchNorm
    spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]
    norm2: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jr.split(key, 4)
        self.norm1 = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        self.spectral_linear = eqx.nn.SpectralNorm(
            layer=eqx.nn.Linear(in_features=3, out_features=32, key=key1),
            weight_name="weight",
            key=key2,
        )
        self.norm2 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key3)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key4)

    def __call__(self, x, state):
        x, state = self.norm1(x, state)
        x, state = self.spectral_linear(x, state)
        x = jax.nn.relu(x)
        x, state = self.norm2(x, state)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state
```
We see from the above that we just define our models like normal. As advertised, we just need to thread the additional `state` object in and out of every call; an updated state object is returned.

There's really nothing special here about stateful layers. Equinox isn't special-casing them in any way. We thread `state` in and out, just like we thread `x` in and out. In fact, calling it "state" is really just a matter of how it's advertised!
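To underline that point, here's a sketch of a hand-rolled stateful layer, built from the public `eqx.nn.StateIndex` hook together with `state.get`/`state.set`; it does nothing but count how many times it has been called:

```python
class CallCounter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        # Register an initial state of zero with this layer.
        self.index = eqx.nn.StateIndex(jnp.array(0))

    def __call__(self, x, state):
        count = state.get(self.index)
        # `state.set` returns an updated copy; the state is never mutated in place.
        new_state = state.set(self.index, count + 1)
        return x, new_state
```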
Alright, now let's see how we might train this model. This is also much like normal.
Note the use of `out_axes`: our data is batched, but our model state isn't batched -- just like how our model parameters aren't batched.
Note how the `axis_name` argument matches the `axis_name` argument that the `BatchNorm` layers were initialised with. This tells `BatchNorm` which vmap'd axis it should compute statistics over. (This is a detail specific to `BatchNorm`, and is unrelated to stateful operations in general.)
```python
def compute_loss(model, state, xs, ys):
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(xs, state)
    loss = jnp.mean((pred_ys - ys) ** 2)
    return loss, state


@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    grads, state = eqx.filter_grad(compute_loss, has_aux=True)(model, state, xs, ys)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state
```
And now, let's see how we initialise this model, and initialise its state.
```python
dataset_size = 10
learning_rate = 3e-4
steps = 5
seed = 5678

key = jr.PRNGKey(seed)
mkey, xkey, xkey2 = jr.split(key, 3)

model, state = eqx.nn.make_with_state(Model)(mkey)

xs = jr.normal(xkey, (dataset_size, 3))
ys = jnp.sin(xs) + 1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)
```
Here we come to the only interesting bit about using stateful layers! When we initialise the model -- e.g. if we were to call `Model(mkey)` directly -- then the model PyTree would be initialised containing both (a) the initial parameters, and (b) the initial state. So `make_with_state` simply calls this, and then separates these two things. The returned `model` is a PyTree holding all the initial parameters (just like any other model), and `state` is a PyTree holding the initial state.
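As a rough sketch (the real implementation may differ in its details), `make_with_state` behaves something like the following, in terms of the public `eqx.nn.State` and `eqx.nn.delete_init_state` APIs:

```python
# Hypothetical re-implementation, for illustration only.
def sketch_make_with_state(mkey):
    full_model = Model(mkey)  # parameters and initial state bundled together
    state = eqx.nn.State(full_model)  # extract a copy of the initial state...
    model = eqx.nn.delete_init_state(full_model)  # ...then drop it from the model
    return model, state
```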
Finally, let's use our trained model to perform inference.
Remember to set the inference flag! Some layers have different behaviour between training and inference, and `BatchNorm` is one of these. (This is a detail specific to layers like `BatchNorm` and `equinox.nn.Dropout`, and is unrelated to stateful operations in general.)
We also fix the final state in the model, using `equinox.Partial`. The resulting `inference_model` is a PyTree (specifically, an `equinox.Partial`) containing both the trained model and the final state.
```python
inference_model = eqx.nn.inference_mode(model)
inference_model = eqx.Partial(inference_model, state=state)


@eqx.filter_jit
def evaluate(model, xs):
    # Discard the updated state.
    out, _ = jax.vmap(model)(xs)
    return out


test_dataset_size = 5
test_xs = jr.normal(xkey2, (test_dataset_size, 3))
pred_test_ys = evaluate(inference_model, test_xs)
```
Here, we don't need the updated state object that is produced, so we just discard it.
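If you did need the updated running statistics -- for example, to continue training afterwards -- then a variation like the following sketch, applied to the training-mode `model` and its `state` and reusing the `vmap` pattern from above, would return the new state as well:

```python
@eqx.filter_jit
def evaluate_with_state(model, state, xs):
    # Same pattern as training: batched data, unbatched state.
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    out, state = batch_model(xs, state)
    return out, state
```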