# Stateful operations (e.g. BatchNorm)

Some layers, such as `equinox.nn.BatchNorm`, are sometimes called "stateful": this refers to the fact that they take an additional input (in the case of BatchNorm, the running statistics) and return an additional output (the updated running statistics).

This just means that we need to plumb an extra input and output through our models. This example demonstrates both `equinox.nn.BatchNorm` and `equinox.nn.SpectralNorm`.
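Stripped of any library machinery, a "stateful" operation is just a pure function with one extra argument and one extra return value. As a minimal sketch in plain JAX (the function `centre_with_running_mean` and its `momentum` parameter are hypothetical, not part of Equinox):

```python
import jax.numpy as jnp


def centre_with_running_mean(x, running_mean, momentum=0.9):
    """Hypothetical toy 'stateful' function, not an Equinox layer.

    Besides the usual input/output `x`, it takes an extra input
    (`running_mean`) and returns an extra output (the updated running mean).
    """
    batch_mean = jnp.mean(x, axis=0)
    new_running_mean = momentum * running_mean + (1 - momentum) * batch_mean
    return x - batch_mean, new_running_mean


running_mean = jnp.zeros(3)
x = jnp.ones((4, 3))
y, running_mean = centre_with_running_mean(x, running_mean)
```

Nothing is mutated in place: to keep updating the statistics, you pass the returned `running_mean` into the next call.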

See also the stateful API reference.

This example is available as a Jupyter notebook here.

```
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax # https://github.com/deepmind/optax
```

```
# This model is just a weird mish-mash of stateful and non-stateful layers for
# demonstration purposes; it isn't doing anything clever.
class Model(eqx.Module):
    norm1: eqx.nn.BatchNorm
    spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]
    norm2: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jr.split(key, 4)
        self.norm1 = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        self.spectral_linear = eqx.nn.SpectralNorm(
            layer=eqx.nn.Linear(in_features=3, out_features=32, key=key1),
            weight_name="weight",
            key=key2,
        )
        self.norm2 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key3)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key4)

    def __call__(self, x, state):
        # Stateful layers take and return `state`; ordinary layers don't.
        x, state = self.norm1(x, state)
        x, state = self.spectral_linear(x, state)
        x = jax.nn.relu(x)
        x, state = self.norm2(x, state)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state
```

As the code above shows, we define our model just like normal. As advertised, the only difference is that we thread the additional `state` object in and out of every call, and each call returns an updated state object.

There's really nothing special here about stateful layers. Equinox isn't special-casing them in any way. We thread `state` in and out, just like we thread `x` in and out. In fact, calling it "state" is really just a matter of how it's advertised!

Alright, now let's see how we might train this model. This is also much like normal.

Note the use of `in_axes` and `out_axes`: our data is batched, but our model state isn't -- just like how our model parameters aren't batched.
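The `in_axes`/`out_axes` mechanics are plain `jax.vmap`. In this toy sketch (the function `layer` is hypothetical), `None` marks the second argument and the second output as unbatched; JAX raises an error if an output marked `None` actually varies across the batch.

```python
import jax
import jax.numpy as jnp


def layer(x, scale):
    # `x` is a single example; `scale` plays the role of unbatched
    # parameters or state. The second output depends only on `scale`.
    return x * scale, scale + 1.0


xs = jnp.arange(4.0)  # a batch of 4 scalar examples
scale = jnp.array(2.0)  # shared by every example

batched = jax.vmap(layer, in_axes=(0, None), out_axes=(0, None))
ys, new_scale = batched(xs, scale)
# `ys` has a batch dimension; `new_scale` is a single scalar.
```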

Note how the `axis_name` argument matches the `axis_name` argument that the `BatchNorm` layers were initialised with. This tells `BatchNorm` which vmap'd axis it should compute statistics over. (This is a detail specific to `BatchNorm`, and is unrelated to stateful operations in general.)
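As a simplified sketch of what a named vmap axis makes possible (this is not Equinox's actual `BatchNorm` implementation; `normalise` is a hypothetical function), a collective such as `jax.lax.pmean` can average over the vmapped axis by name, so each per-example call sees statistics for the whole batch:

```python
import jax
import jax.numpy as jnp


def normalise(x, running_mean):
    # `x` is one example. `pmean` averages over the vmapped axis named
    # "batch", so `batch_mean` is the mean over the whole batch.
    batch_mean = jax.lax.pmean(x, axis_name="batch")
    new_running_mean = 0.9 * running_mean + 0.1 * batch_mean
    return x - batch_mean, new_running_mean


xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # batch of 2 examples, 2 features
running_mean = jnp.zeros(2)

ys, running_mean = jax.vmap(
    normalise, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(xs, running_mean)
```

Because `pmean` reduces over the named axis, `batch_mean` is identical for every example, which is what lets the updated running mean be returned unbatched with `out_axes=None`.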

```
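def compute_loss(model, state, xs, ys):
    # Batch over the data (axis 0) but not over the state (None). The
    # axis_name must match the one the BatchNorm layers were created with.
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(xs, state)
    loss = jnp.mean((pred_ys - ys) ** 2)
    return loss, state


@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    # has_aux=True: take gradients of the loss w.r.t. `model`, and pass the
    # updated state through as an auxiliary output.
    grads, state = eqx.filter_grad(compute_loss, has_aux=True)(model, state, xs, ys)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state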
```

And now, let's see how we initialise this model, and initialise its state.

```
dataset_size = 10
learning_rate = 3e-4
steps = 5
seed = 5678

key = jr.PRNGKey(seed)
mkey, xkey, xkey2 = jr.split(key, 3)

# Build the model and split out its initial state.
model, state = eqx.nn.make_with_state(Model)(mkey)

xs = jr.normal(xkey, (dataset_size, 3))
ys = jnp.sin(xs) + 1
optim = optax.adam(learning_rate)
# Only the floating-point arrays in the model are optimised.
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

for _ in range(steps):
    # Full-batch gradient descent in this simple example.
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)
```

What is `eqx.nn.make_with_state` doing?

Here we come to the only interesting bit about using stateful layers!

When we initialise the model -- e.g. if we were to call `Model(mkey)` directly -- the model PyTree would contain both (a) the initial parameters and (b) the initial state. `make_with_state` simply calls this, and then separates these two things. The returned `model` is a PyTree holding all the initial parameters (just like any other model), and `state` is a PyTree holding the initial state.

Finally, let's use our trained model to perform inference.

Remember to set the inference flag (here via `eqx.nn.inference_mode`)! Some layers behave differently during training and inference, and `BatchNorm` is one of these. (This is a detail specific to layers like `BatchNorm` and `equinox.nn.Dropout`, and is unrelated to stateful operations in general.)

We also fix the final state in the model, using `equinox.Partial`. The resulting `inference_model` is a PyTree (specifically, an `equinox.Partial`) containing both `model` and `state`.

```
inference_model = eqx.nn.inference_mode(model)
inference_model = eqx.Partial(inference_model, state=state)


@eqx.filter_jit
def evaluate(model, xs):
    # discard state
    out, _ = jax.vmap(model)(xs)
    return out


test_dataset_size = 5
test_xs = jr.normal(xkey2, (test_dataset_size, 3))
pred_test_ys = evaluate(inference_model, test_xs)
```

Here, we don't need the updated state object that is produced, so we just discard it.