Stateful operations
These operations can be used to save and load JAX arrays as a side effect of JAX operations, even under JIT.
Warning
This is considered experimental.
Stateful operations will not produce correct results under jax.checkpoint or jax.pmap.
Danger
Really, this is experimental. Side effects can easily make your code do something unexpected. Whatever you're doing, you almost certainly do not need this.
Use cases:
- Something like equinox.experimental.BatchNorm, for which we would like to save the running statistics as a side effect.
- Implicitly passing information between loop iterations, i.e. rather than explicitly via the carry argument to lax.scan. Perhaps you're using a third-party library that handles the lax.scan, and that doesn't allow you to pass your own information between iterations.
Example
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
index = eqx.experimental.StateIndex()
init = jnp.array(0)
eqx.experimental.set_state(index, init)
@jax.jit
def scan_fun(_, __):
    val = eqx.experimental.get_state(index, like=init)
    val = val + 1
    eqx.experimental.set_state(index, val)
    return None, val
_, out = lax.scan(scan_fun, None, xs=None, length=5)
print(out) # [1 2 3 4 5]
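For comparison, here is the same counter written without side effects, with the running value passed explicitly through the carry. To keep the sketch self-contained it uses a plain-Python `scan` helper (hypothetical, not part of JAX or Equinox) that follows the reference semantics of `lax.scan` from the JAX documentation; with JAX installed, `lax.scan(step, 0, xs=None, length=5)` should compute the same values as JAX arrays.

```python
def scan(f, init, xs=None, length=None):
    """Plain-Python model of lax.scan's semantics (hypothetical helper)."""
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # the carry is threaded through explicitly
        ys.append(y)
    return carry, ys


def step(carry, _):
    # The running count lives in the carry, not in external state.
    val = carry + 1
    return val, val


final, out = scan(step, 0, length=5)
print(final, out)  # 5 [1, 2, 3, 4, 5]
```

The side-effecting version is only worth considering when, as in the second use case above, you cannot control what is passed through the carry.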
equinox.experimental.StateIndex (Module)
An index for setting or getting a piece of state with equinox.experimental.get_state or equinox.experimental.set_state.
You should typically treat this like a model parameter.
Example
import equinox as eqx
import equinox.experimental as eqxe
import jax.numpy as jnp
class CacheInput(eqx.Module):
    index: eqxe.StateIndex

    def __init__(self, input_shape):
        self.index = eqxe.StateIndex()
        eqxe.set_state(self.index, jnp.zeros(input_shape))

    def __call__(self, x):
        last_x = eqxe.get_state(self.index, x)
        eqxe.set_state(self.index, x)
        print(f"last_x={last_x}, x={x}")
x = jnp.array([1., 2.])
y = jnp.array([3., 4.])
shape = x.shape
ci = CacheInput(shape)
ci(x) # last_x=[0. 0.], x=[1. 2.]
ci(y) # last_x=[1. 2.], x=[3. 4.]
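One way to picture a StateIndex is as a key into a store that lives outside the model: the module holds only the key, and get_state/set_state read from and write to the store. The sketch below is a toy model of that idea in plain Python, not Equinox's implementation; `ToyStateIndex`, `STORE`, `toy_get_state`, `toy_set_state` and `ToyCacheInput` are hypothetical names, and shapes, dtypes and JIT are not modelled.

```python
import dataclasses

# Hypothetical global store standing in for wherever the real library keeps state.
STORE = {}


@dataclasses.dataclass(frozen=True)
class ToyStateIndex:
    # A fresh, unique key per index; copies of the index share this key.
    key: object = dataclasses.field(default_factory=object)


def toy_set_state(index, state):
    STORE[index.key] = state


def toy_get_state(index):
    return STORE[index.key]


@dataclasses.dataclass(frozen=True)
class ToyCacheInput:
    index: ToyStateIndex

    def __call__(self, x):
        last_x = toy_get_state(self.index)
        toy_set_state(self.index, x)
        return last_x


ci = ToyCacheInput(ToyStateIndex())
toy_set_state(ci.index, [0.0, 0.0])
print(ci([1.0, 2.0]))  # [0.0, 0.0]
print(ci([3.0, 4.0]))  # [1.0, 2.0]

# In this toy model, a copy of the module shares the key, so it reads the same state.
ci_copy = dataclasses.replace(ci)
print(ci_copy([5.0, 6.0]))  # [3.0, 4.0]
```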
__init__(self, inference: bool = False)
Arguments:
inference: If True, then the state can only be read, not set. All stored states will be looked up when crossing the JIT boundary (rather than dynamically at runtime) and treated as inputs to the XLA computation graph. This improves speed at runtime. This may be toggled with equinox.tree_inference.
Warning
You should not modify the inference flag whilst inside a JIT region. For example, the following will produce undefined behaviour:
@jax.jit
def f(index):
    ...
    index = eqx.tree_at(lambda i: i.inference, index, True)  # don't do this
    ...
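The inference flag can be pictured with a toy model: an index with inference=True refuses writes, and toggling the flag produces a new index that still refers to the same stored state. The sketch below is plain Python and purely illustrative; `ToyStateIndex`, `STORE`, `toy_set_state` and `toy_get_state` are hypothetical, and the toggle uses `dataclasses.replace` outside of any JIT region rather than equinox.tree_inference.

```python
import dataclasses

STORE = {}  # hypothetical global store: index key -> saved state


@dataclasses.dataclass(frozen=True)
class ToyStateIndex:
    inference: bool = False
    # Unique key shared by every copy of this index.
    key: object = dataclasses.field(default_factory=object)


def toy_set_state(index, state):
    if index.inference:
        # Mirrors the documented RuntimeError when index.inference is truthy.
        raise RuntimeError("cannot set state when index.inference is True")
    STORE[index.key] = state


def toy_get_state(index):
    return STORE[index.key]


index = ToyStateIndex()
toy_set_state(index, 1.0)

# Toggle the flag once, up front, by building a new (frozen) index.
inference_index = dataclasses.replace(index, inference=True)
print(toy_get_state(inference_index))  # 1.0

try:
    toy_set_state(inference_index, 2.0)
except RuntimeError as e:
    print("RuntimeError:", e)
```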
equinox.experimental.get_state(index: StateIndex, like: PyTree[Array]) -> PyTree[Array]
Get some previously saved state.
Arguments:
index: The index of the state to look up. Should be an instance of equinox.experimental.StateIndex.
like: A PyTree of JAX arrays of the same shape, dtype, PyTree structure, and batch axes as the state being looked up.
Returns:
The previously saved state.
Raises:
- A TypeError at trace time if like is not a PyTree of JAX arrays.
- A RuntimeError at run time if like is not of the same shape, dtype, PyTree structure, and batch axes as the retrieved value.
- A RuntimeError at run time if no state has previously been saved with this index.
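To make the failure modes listed above concrete, here is a plain-Python toy model of the checks get_state performs. It is not Equinox's implementation: a state is restricted to a single `array.array` or a tuple of them, the array length and typecode stand in for shape and dtype, batch axes are not modelled, and `ToyStateIndex`, `STORE`, `toy_set_state` and `toy_get_state` are hypothetical names. Everything is also checked eagerly here, whereas the real function distinguishes trace-time from run-time errors.

```python
from array import array

STORE = {}  # hypothetical global store: index -> saved state


class ToyStateIndex:
    """Hypothetical stand-in for StateIndex; compared by identity."""


def _signature(tree):
    # Stand-in for (PyTree structure, shapes, dtypes).
    leaves = tree if isinstance(tree, tuple) else (tree,)
    if not all(isinstance(leaf, array) for leaf in leaves):
        raise TypeError("expected an array.array or a tuple of them")
    return isinstance(tree, tuple), tuple((len(x), x.typecode) for x in leaves)


def toy_set_state(index, state):
    _signature(state)  # validate that the state is a "PyTree of arrays"
    STORE[index] = state


def toy_get_state(index, like):
    like_sig = _signature(like)  # TypeError if `like` is the wrong kind of object
    if index not in STORE:
        raise RuntimeError("no state has been saved with this index")
    state = STORE[index]
    if _signature(state) != like_sig:
        raise RuntimeError("`like` does not match the saved state")
    return state


index = ToyStateIndex()
toy_set_state(index, array("f", [1.0, 2.0]))
print(toy_get_state(index, like=array("f", [0.0, 0.0])).tolist())  # [1.0, 2.0]

for bad_like in (array("f", [0.0]), array("d", [0.0, 0.0])):
    try:
        toy_get_state(index, like=bad_like)
    except RuntimeError as e:
        print("RuntimeError:", e)
```

Note that `like` is only used for its metadata; its values are ignored.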
Warning
Calling get_state means that your operation will no longer be a pure function.
equinox.experimental.set_state(index: StateIndex, state: PyTree[Array]) -> None
Save a PyTree of JAX arrays as a side-effect.
Arguments:
index: A key under which to save the state. Should be an instance of equinox.experimental.StateIndex.
state: A PyTree of JAX arrays to save.
Returns:
None.
Raises:
- A RuntimeError at run time if this index has previously been used to save a state with a different shape, dtype, PyTree structure, or batch axes.
- A RuntimeError at trace time if index.inference is truthy.
- A TypeError at trace time if state is not a PyTree of JAX arrays.
- A NotImplementedError at trace time if trying to compute a gradient through state.
Info
The same index can be used multiple times, to overwrite a previously saved value. The new and old state must have the same PyTree structure, shape, dtype, and batch axes, however.
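The set_state rules (overwriting is allowed, but only with a state of matching structure, shape and dtype, and never through an index with inference=True) can be sketched as a plain-Python toy model. This is illustrative only, not Equinox's implementation: `array.array` length and typecode stand in for shape and dtype, batch axes and gradients are not modelled, and `ToyStateIndex`, `STORE` and `toy_set_state` are hypothetical names.

```python
import dataclasses
from array import array

STORE = {}  # hypothetical global store: index key -> saved state


@dataclasses.dataclass(frozen=True)
class ToyStateIndex:
    inference: bool = False
    key: object = dataclasses.field(default_factory=object)


def _signature(tree):
    # Stand-in for (PyTree structure, shapes, dtypes).
    leaves = tree if isinstance(tree, tuple) else (tree,)
    if not all(isinstance(leaf, array) for leaf in leaves):
        raise TypeError("expected an array.array or a tuple of them")
    return isinstance(tree, tuple), tuple((len(x), x.typecode) for x in leaves)


def toy_set_state(index, state):
    if index.inference:
        raise RuntimeError("cannot set state when index.inference is True")
    new_sig = _signature(state)  # TypeError if not a "PyTree of arrays"
    if index.key in STORE and _signature(STORE[index.key]) != new_sig:
        raise RuntimeError("state does not match the previously saved state")
    STORE[index.key] = state  # first save, or an overwrite


index = ToyStateIndex()
toy_set_state(index, array("f", [0.0, 0.0]))
toy_set_state(index, array("f", [1.0, 2.0]))  # same signature: overwrite is fine

try:
    toy_set_state(index, (array("f", [1.0, 2.0]),))  # different structure
except RuntimeError as e:
    print("RuntimeError:", e)
```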
Warning
Note that state cannot be differentiated.
Warning
Calling set_state means that your operation will no longer be a pure function. Moreover, note that the save may happen even when set_state is wrapped in lax.cond or similar. (For example, under vmap a lax.cond with a batched predicate is transformed into a lax.select, so both branches are executed.)
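Why a save can escape a lax.cond is easiest to see by comparing the two evaluation strategies. The plain-Python sketch below is a hypothetical illustration, not JAX code: `cond_like` runs only the chosen branch, while `select_like` mimics a select by evaluating both branches and then keeping one result, so a side effect in the untaken branch still happens. `saved`, `save_and_return`, `cond_like` and `select_like` are all made-up names.

```python
saved = []  # records every "save" performed by the branches


def save_and_return(x):
    saved.append(x)  # side effect standing in for set_state
    return x


def cond_like(pred, true_fn, false_fn, x):
    # Only the selected branch runs.
    return true_fn(x) if pred else false_fn(x)


def select_like(pred, true_fn, false_fn, x):
    # Both branches run; the predicate only chooses which result to keep.
    on_true, on_false = true_fn(x), false_fn(x)
    return on_true if pred else on_false


cond_like(False, save_and_return, lambda x: x, 1)
print(saved)  # []

select_like(False, save_and_return, lambda x: x, 2)
print(saved)  # [2]
```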