Stateful operations¤

These are the tools that underlie stateful operations, like equinox.nn.BatchNorm or equinox.nn.SpectralNorm. These are fairly unusual layers, so most users will not need this part of the API.

Example

The stateful example is a good reference for the typical workflow for stateful layers.


equinox.nn.make_with_state(make_model: Callable[~_P, ~_T]) -> Callable[~_P, tuple[~_T, State]] ¤

This function is the most common API for working with stateful models. It initialises both the parameters and the state of a stateful model.

eqx.nn.make_with_state(Model)(*args, **kwargs) simply calls model_with_state = Model(*args, **kwargs), and then partitions the resulting PyTree into two pieces: the parameters, and the state.

Arguments:

  • make_model: some callable returning a PyTree.

Returns:

A callable, which when evaluated returns a 2-tuple of (model, state), where model is the result of make_model(*args, **kwargs) but with all of the initial states stripped out, and state is an equinox.nn.State object encapsulating the initial states.

Example

See the stateful example for a runnable example.

class Model(eqx.Module):
    def __init__(self, foo, bar):
        ...

    ...

model, state = eqx.nn.make_with_state(Model)(foo=3, bar=4)

Extra features¤

Let's explain how this works under the hood. First of all, all stateful layers (BatchNorm etc.) include an "index". This is basically just a unique hashable value (used later as a dictionary key), and an initial value for the state:

equinox.nn.StateIndex (Module) ¤

This wraps together (a) a unique dictionary key used for looking up a stateful value, and (b) how that stateful value should be initialised.

Example

class MyStatefulLayer(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        current_state = state.get(self.index)
        new_x = x + current_state
        new_state = state.set(self.index, current_state + 1)
        return new_x, new_state

See also e.g. the source code of built-in stateful layers like equinox.nn.BatchNorm for further reference.

__init__(self, init: ~_Value) ¤

Arguments:

  • init: The initial value for the state.

This State object that's being passed around is essentially just a dictionary, mapping from StateIndexs to PyTrees-of-arrays. Correspondingly this has .get and .set methods to read and write values to it.
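To make the "essentially just a dictionary" picture concrete, here is a toy pure-Python sketch. `ToyIndex` and `ToyState` are hypothetical names invented for illustration; this is not Equinox's implementation, and it deliberately omits the real class's PyTree handling and its invalidation of old objects.

```python
class ToyIndex:
    """Hypothetical stand-in for a state index: a unique key plus an initial value."""

    def __init__(self, init):
        self.init = init  # hashed by object identity, so each instance is a unique key


class ToyState:
    """Hypothetical stand-in for a state object: a dictionary keyed by indices."""

    def __init__(self, indices):
        self._values = {index: index.init for index in indices}

    def get(self, index):
        return self._values[index]

    def set(self, index, value):
        # Functional update: return a new object rather than mutating in place.
        new = ToyState([])
        new._values = {**self._values, index: value}
        return new


index = ToyIndex(init=0)
state = ToyState([index])
new_state = state.set(index, state.get(index) + 1)
print(new_state.get(index))  # 1
```

The design point carried over from the real API is that `set` returns a new object, so the caller has to rebind the name it uses for the state.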

equinox.nn.State ¤

Stores the state of a model. For example, the running statistics of all equinox.nn.BatchNorm layers in the model.

This is essentially a dictionary mapping from equinox.nn.StateIndexs to PyTrees of arrays.

This class should be initialised via equinox.nn.make_with_state.

get(self, item: StateIndex[~_Value]) -> ~_Value ¤

Given an equinox.nn.StateIndex, returns the value of its state.

Arguments:

  • item: an equinox.nn.StateIndex.

Returns:

The current state associated with that index.

set(self, item: StateIndex[~_Value], value: ~_Value) -> State ¤

Sets a new value for an equinox.nn.StateIndex, and returns the updated state.

Arguments:

  • item: an equinox.nn.StateIndex.
  • value: the new value to associate with that index.

Returns:

A new equinox.nn.State object, with the update.

As a safety guard against accidentally writing state.set(item, value) without assigning the result to a new variable, the old object (self) becomes invalid once .set has been called.

substate(self, pytree: PyTree) -> State ¤

Creates a smaller State object, that tracks only the states of some smaller part of the overall model.

Arguments:

  • pytree: any PyTree. It will be iterated over to check for equinox.nn.StateIndexs.

Returns:

A new equinox.nn.State object, which tracks only some of the overall states.

update(self, substate: State) -> State ¤

Takes a smaller State object (typically produced via .substate), and updates the states of self using all of its values.

Arguments:

  • substate: a State object whose keys are a subset of the keys of self.

Returns:

A new equinox.nn.State object, containing all of the updated values.

As a safety guard against accidentally writing state.update(substate) without assigning the result to a new variable, the old object (self) becomes invalid once .update has been called.

Custom stateful layers¤

Let's use equinox.nn.StateIndex to create a custom stateful layer.

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value
        new_state = state.set(self.index, value + 1)
        return new_x, new_state

counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 1

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 2

Vmap'd stateful layers¤

This is an advanced thing to do! Here we'll build on the ensembling guide, and see how we can create vmap'd stateful layers.

This follows on from the previous example, in which we define Counter.

import jax.random as jr

class Model(eqx.Module):
    linear: eqx.nn.Linear
    counter: Counter
    v_counter: Counter

    def __init__(self, key):
        # Non-stateful layer.
        self.linear = eqx.nn.Linear(2, 2, key=key)
        # Stateful layer.
        self.counter = Counter()
        # Vmap'd stateful layer. (Whose initial state will include a batch dimension.)
        self.v_counter = eqx.filter_vmap(Counter, axis_size=2)()

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        # This bit happens as normal.
        assert x.shape == (2,)
        x = self.linear(x)
        x, state = self.counter(x, state)

        # For the vmap, we have to restrict our state to just those states we want to
        # vmap, and then update the overall state again afterwards.
        #
        # After all, the state for `self.counter` isn't expecting to be batched, so we
        # have to remove that.
        substate = state.substate(self.v_counter)
        x, substate = eqx.filter_vmap(self.v_counter)(x, substate)
        state = state.update(substate)

        return x, state

key = jr.PRNGKey(0)
model, state = eqx.nn.make_with_state(Model)(key)
x = jnp.array([5.0, -1.0])
model(x, state)