Stateful operations (e.g. BatchNorm)
Some layers, such as equinox.nn.BatchNorm, are sometimes called "stateful": this refers to the fact that they take an additional input (in the case of BatchNorm, the running statistics) and return an additional output (the updated running statistics). This just means that we need to plumb an extra input and output through our models. This example demonstrates both equinox.nn.BatchNorm and equinox.nn.SpectralNorm.

See also the stateful API reference.

This example is available as a Jupyter notebook here.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax # https://github.com/deepmind/optax
# This model is just a weird mish-mash of stateful and non-stateful layers for
# demonstration purposes; it isn't doing anything clever.
class Model(eqx.Module):
    norm1: eqx.nn.BatchNorm
    spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]
    norm2: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jr.split(key, 4)
        self.norm1 = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        self.spectral_linear = eqx.nn.SpectralNorm(
            layer=eqx.nn.Linear(in_features=3, out_features=32, key=key1),
            weight_name="weight",
            key=key2,
        )
        self.norm2 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key3)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key4)

    def __call__(self, x, state):
        x, state = self.norm1(x, state)
        x, state = self.spectral_linear(x, state)
        x = jax.nn.relu(x)
        x, state = self.norm2(x, state)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state
We see from the above that we just define our model like normal. As advertised, we just need to thread the additional state object in and out of every call; an updated state object is returned.

There's really nothing special here about stateful layers. Equinox isn't special-casing them in any way: we thread state in and out, just like we thread x in and out. In fact, calling it "state" is really just a matter of how it's advertised!
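To make the threading concrete, here is a minimal sketch (not part of the original example) showing the same pattern with a single BatchNorm layer on its own. The names bn, bn_state and dummy_xs are illustrative only; eqx.nn.make_with_state, used here to obtain the initial state, is discussed further below.

# A minimal sketch of state threading with a single stateful layer.
bn, bn_state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(input_size=3, axis_name="batch")
dummy_xs = jnp.ones((8, 3))
# BatchNorm needs a named batch axis to compute statistics over, hence the vmap.
batch_bn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, bn_state = batch_bn(dummy_xs, bn_state)  # extra input in, updated state out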
Alright, now let's see how we might train this model. This is also much like normal.

Note the use of in_axes and out_axes: our data is batched, but our model state isn't batched -- just like how our model parameters aren't batched.

Note how the axis_name argument matches the axis_name argument that the BatchNorm layers were initialised with. This tells BatchNorm which vmap'd axis it should compute statistics over. (This is a detail specific to BatchNorm, and is unrelated to stateful operations in general.)
def compute_loss(model, state, xs, ys):
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(xs, state)
    loss = jnp.mean((pred_ys - ys) ** 2)
    return loss, state


@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    grads, state = eqx.filter_grad(compute_loss, has_aux=True)(model, state, xs, ys)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state
And now, let's see how we initialise this model, and initialise its state.
dataset_size = 10
learning_rate = 3e-4
steps = 5
seed = 5678

key = jr.PRNGKey(seed)
mkey, xkey, xkey2 = jr.split(key, 3)

model, state = eqx.nn.make_with_state(Model)(mkey)

xs = jr.normal(xkey, (dataset_size, 3))
ys = jnp.sin(xs) + 1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)
What is eqx.nn.make_with_state doing?

Here we come to the only interesting bit about using stateful layers!

When we initialise the model -- e.g. if we were to call Model(mkey) directly -- then the model PyTree would be initialised containing both (a) the initial parameters, and (b) the initial state. So make_with_state simply calls this, and then separates these two things. The returned model is a PyTree holding all the initial parameters (just like any other model), and state is a PyTree holding the initial state.
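In other words, make_with_state is roughly equivalent to the following sketch. (This is a simplification for illustration, assuming the eqx.nn.State and eqx.nn.delete_init_state helpers from the stateful API reference; see that reference for the precise behaviour.)

# Rough sketch of what eqx.nn.make_with_state does; not the actual implementation.
def make_with_state_sketch(cls):
    def _init(*args, **kwargs):
        model = cls(*args, **kwargs)             # PyTree holding parameters + initial state
        state = eqx.nn.State(model)              # extract the initial state
        model = eqx.nn.delete_init_state(model)  # drop the initial state from the model
        return model, state

    return _init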
Finally, let's use our trained model to perform inference.

Remember to set the inference flag! Some layers have different behaviour between training and inference, and BatchNorm is one of these. (This is a detail specific to layers like BatchNorm and equinox.nn.Dropout, and is unrelated to stateful operations in general.)

We also fix the final state in the model, using equinox.Partial. The resulting inference_model is a PyTree (specifically, an equinox.Partial) containing both model and state.
inference_model = eqx.nn.inference_mode(model)
inference_model = eqx.Partial(inference_model, state=state)


@eqx.filter_jit
def evaluate(model, xs):
    # discard state
    out, _ = jax.vmap(model)(xs)
    return out


test_dataset_size = 5
test_xs = jr.normal(xkey2, (test_dataset_size, 3))
pred_test_ys = evaluate(inference_model, test_xs)
Here, we don't need the updated state object that is produced, so we just discard it.
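(If you did want to keep threading the state around explicitly at inference time, rather than baking it into an eqx.Partial, you could do so just as during training. The following is a hypothetical variation, not part of the example above; note that in inference mode the BatchNorm and SpectralNorm layers return the state unchanged.)

# Hypothetical variation: thread the state explicitly instead of using eqx.Partial.
@eqx.filter_jit
def evaluate_with_state(model, state, xs):
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    out, state = batch_model(xs, state)
    return out, state


pred_test_ys2, _ = evaluate_with_state(eqx.nn.inference_mode(model), state, test_xs)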