Stateful operations
These are the tools that underlie stateful operations, like `equinox.nn.BatchNorm` or `equinox.nn.SpectralNorm`. These are fairly unusual layers, so most users will not need this part of the API.
Example
The stateful example is a good reference for the typical workflow for stateful layers.
equinox.nn.make_with_state(make_model: Callable[~_P, ~_T]) -> Callable[~_P, tuple[~_T, State]]
This function is the most common API for working with stateful models. It initialises both the parameters and the state of a stateful model.

`eqx.nn.make_with_state(Model)(*args, **kwargs)` simply calls `model_with_state = Model(*args, **kwargs)`, and then partitions the resulting PyTree into two pieces: the parameters, and the state.
Arguments:

- `make_model`: some callable returning a PyTree.

Returns:

A callable, which when evaluated returns a 2-tuple of `(model, state)`, where `model` is the result of `make_model(*args, **kwargs)` but with all of the initial states stripped out, and `state` is an `equinox.nn.State` object encapsulating the initial states.
Example
See the stateful example for a runnable example.
```python
import equinox as eqx

class Model(eqx.Module):
    def __init__(self, foo, bar):
        ...
    ...

model, state = eqx.nn.make_with_state(Model)(foo=3, bar=4)
```
Extra features
Let's explain how this works under the hood. First of all, all stateful layers (`BatchNorm` etc.) include an "index". Each index bundles a unique hashable value (used later as a dictionary key) together with an initial value for the state:
equinox.nn.StateIndex (Module)
This wraps together (a) a unique dictionary key used for looking up a stateful value, and (b) how that stateful value should be initialised.
Example
```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class MyStatefulLayer(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        current_state = state.get(self.index)
        new_x = x + current_state
        new_state = state.set(self.index, current_state + 1)
        return new_x, new_state
```
See also e.g. the source code of built-in stateful layers like `equinox.nn.BatchNorm` for further reference.
__init__(self, init: ~_Value)
Arguments:

- `init`: The initial value for the state.
This `State` object that's being passed around is essentially just a dictionary, mapping from `StateIndex`s to PyTrees-of-arrays. Correspondingly it has `.get` and `.set` methods to read and write values to it.
equinox.nn.State
Stores the state of a model. For example, the running statistics of all `equinox.nn.BatchNorm` layers in the model.

This is essentially a dictionary mapping from `equinox.nn.StateIndex`s to PyTrees of arrays.

This class should be initialised via `equinox.nn.make_with_state`.
get(self, item: StateIndex[~_Value]) -> ~_Value
Given an `equinox.nn.StateIndex`, returns the value of its state.

Arguments:

- `item`: an `equinox.nn.StateIndex`.

Returns:

The current state associated with that index.
set(self, item: StateIndex[~_Value], value: ~_Value) -> State
Sets a new value for an `equinox.nn.StateIndex`, and returns the updated state.

Arguments:

- `item`: an `equinox.nn.StateIndex`.
- `value`: the new value associated with that index.

Returns:

A new `equinox.nn.State` object, with the update.

As a safety guard against accidentally writing `state.set(item, value)` without assigning the result to a new variable, the old object (`self`) will become invalid.
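The copy-and-invalidate behaviour can be pictured with a toy model. This is a simplified pure-Python illustration, not Equinox's implementation: `ToyIndex` and `ToyState` are hypothetical names, values are plain Python objects rather than arrays, and the error message is made up.

```python
class ToyIndex:
    """Hypothetical stand-in for StateIndex: a unique key plus an initial value."""

    def __init__(self, init):
        self.init = init


class ToyState:
    """Hypothetical stand-in for State: a dict from indices to values."""

    def __init__(self, entries):
        # ToyIndex objects hash by identity, so each one is a distinct key.
        self._entries = dict(entries)

    def _check(self):
        if self._entries is None:
            raise ValueError("This ToyState was already consumed by `set`.")

    def get(self, index):
        self._check()
        return self._entries[index]

    def set(self, index, value):
        self._check()
        new_entries = dict(self._entries)  # copy, never mutate in place
        new_entries[index] = value
        self._entries = None  # invalidate the old object
        return ToyState(new_entries)


i = ToyIndex(0)
s0 = ToyState({i: i.init})
s1 = s0.set(i, 1)
print(s1.get(i))  # 1
try:
    s0.get(i)  # the old object is no longer usable
except ValueError as e:
    print("old state rejected:", e)
```

Forgetting to write `s1 = ...` therefore fails loudly on the next read, instead of silently carrying on with a stale value.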
substate(self, pytree: PyTree) -> State
Creates a smaller `State` object that tracks only the states of some smaller part of the overall model.

Arguments:

- `pytree`: any PyTree. It will be iterated over to check for `equinox.nn.StateIndex`s.

Returns:

A new `equinox.nn.State` object, which tracks only some of the overall states.
update(self, substate: State) -> State
Takes a smaller `State` object (typically produced via `.substate`), and updates states by using all of its values.

Arguments:

- `substate`: a `State` object whose keys are a subset of the keys of `self`.

Returns:

A new `equinox.nn.State` object, containing all of the updated values.

As a safety guard against accidentally writing `state.update(substate)` without assigning the result to a new variable, the old object (`self`) will become invalid.
Custom stateful layers
Let's use `equinox.nn.StateIndex` to create a custom stateful layer.
```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array


class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value
        new_state = state.set(self.index, value + 1)
        return new_x, new_state


counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 0

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 1

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.")  # 2
```
Vmap'd stateful layers
This is an advanced thing to do! Here we'll build on the ensembling guide, and see how we can create vmap'd stateful layers.
This follows on from the previous example, in which we define `Counter`.
```python
import jax.random as jr


class Model(eqx.Module):
    linear: eqx.nn.Linear
    counter: Counter
    v_counter: Counter

    def __init__(self, key):
        # Not-stateful layer.
        self.linear = eqx.nn.Linear(2, 2, key=key)
        # Stateful layer.
        self.counter = Counter()
        # Vmap'd stateful layer. (Whose initial state will include a batch dimension.)
        self.v_counter = eqx.filter_vmap(Counter, axis_size=2)()

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        # This bit happens as normal.
        assert x.shape == (2,)
        x = self.linear(x)
        x, state = self.counter(x, state)

        # For the vmap, we have to restrict our state to just those states we want to
        # vmap, and then update the overall state again afterwards.
        #
        # After all, the state for `self.counter` isn't expecting to be batched, so we
        # have to remove that.
        substate = state.substate(self.v_counter)
        x, substate = eqx.filter_vmap(self.v_counter)(x, substate)
        state = state.update(substate)

        return x, state


key = jr.PRNGKey(0)
model, state = eqx.nn.make_with_state(Model)(key)
x = jnp.array([5.0, -1.0])
model(x, state)
```