
Stateful operations¤

These operations can be used to save and load JAX arrays as a side-effect of JAX operations, even under JIT.

Warning

This is considered experimental.

Stateful operations will not produce correct results under jax.checkpoint or jax.pmap.

Danger

Really, this is experimental. Side effects can easily make your code do something unexpected. Whatever you're doing, you almost certainly do not need this.

Use cases:

  • Something like equinox.experimental.BatchNorm, for which we would like to save the running statistics as a side-effect. (A sketch of this pattern follows the example below.)
  • Implicitly passing information between loop iterations -- i.e. rather than explicitly via the carry argument to lax.scan. Perhaps you're using a third-party library that handles the lax.scan and doesn't allow you to pass your own information between iterations.

Example

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
init = jnp.array(0)
eqx.experimental.set_state(index, init)

@jax.jit
def scan_fun(_, __):
    val = eqx.experimental.get_state(index, like=init)
    val = val + 1
    eqx.experimental.set_state(index, val)
    return None, val

_, out = lax.scan(scan_fun, None, xs=None, length=5)
print(out)  # [1 2 3 4 5]
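
The first use case above -- saving running statistics as a side-effect -- can be sketched in the same style. This is a hypothetical toy module (not equinox.experimental.BatchNorm itself) that keeps a running mean as hidden state:

import equinox as eqx
import equinox.experimental as eqxe
import jax.numpy as jnp

class RunningMean(eqx.Module):
    index: eqxe.StateIndex
    momentum: float

    def __init__(self, shape, momentum=0.9):
        self.index = eqxe.StateIndex()
        self.momentum = momentum
        eqxe.set_state(self.index, jnp.zeros(shape))

    def __call__(self, x):
        # Read the running mean, update it, and save it back as a side-effect.
        mean = eqxe.get_state(self.index, like=x)
        mean = self.momentum * mean + (1 - self.momentum) * x
        eqxe.set_state(self.index, mean)
        return x - mean

rm = RunningMean((2,))
rm(jnp.array([1., 2.]))  # running mean is now [0.1 0.2]
rm(jnp.array([1., 2.]))  # running mean is now [0.19 0.38]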

equinox.experimental.StateIndex (Module) ¤

An index for setting or getting a piece of state with equinox.experimental.get_state or equinox.experimental.set_state.

You should typically treat this like a model parameter.

Example

import equinox as eqx
import equinox.experimental as eqxe
import jax.numpy as jnp

class CacheInput(eqx.Module):
    index: eqxe.StateIndex

    def __init__(self, input_shape):
        self.index = eqxe.StateIndex()
        eqxe.set_state(self.index, jnp.zeros(input_shape))

    def __call__(self, x):
        last_x = eqxe.get_state(self.index, x)
        eqxe.set_state(self.index, x)
        print(f"last_x={last_x}, x={x}")

x = jnp.array([1., 2.])
y = jnp.array([3., 4.])
shape = x.shape
ci = CacheInput(shape)
ci(x)  # last_x=[0. 0.], x=[1. 2.]
ci(y)  # last_x=[1. 2.], x=[3. 4.]

__init__(self, inference: bool = False) ¤

Arguments:

  • inference: If True, then the state can only be read (with equinox.experimental.get_state), not set. All stored states will be looked up when crossing the JIT boundary -- rather than dynamically at runtime -- and treated as inputs to the XLA computation graph. This improves speed at runtime. This flag may be toggled with equinox.tree_inference.

Warning

You should not modify the inference flag whilst inside a JIT region. For example, the following will produce undefined behaviour:

@jax.jit
def f(...):
    ...
    index = eqx.tree_at(lambda i: i.inference, index, True)
    ...
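
Instead, toggle the flag outside of any JIT region -- for example with the same eqx.tree_at call as above (or with equinox.tree_inference). A minimal sketch:

import equinox as eqx
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
eqx.experimental.set_state(index, jnp.array(0))

# Outside of JIT: switch the index into inference mode.
index = eqx.tree_at(lambda i: i.inference, index, True)

# get_state is still allowed; set_state would now raise a RuntimeError at trace time.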

equinox.experimental.get_state(index: StateIndex, like: PyTree[Array]) -> PyTree[Array] ¤

Get some previously saved state.

Arguments:

  • index: The index of the state to look up. Should be an instance of equinox.experimental.StateIndex.
  • like: A PyTree of JAX arrays of the same shape, dtype, PyTree structure, and batch axes as the state being looked up.

Returns:

Whatever the previously saved state is.

Raises:

A TypeError at trace time if like is not a PyTree of JAX arrays.

A RuntimeError at run time if like is not of the same shape, dtype, PyTree structure, and batch axes as the retrieved value.

A RuntimeError at run time if no state has previously been saved with this index.
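
For example, a minimal sketch: the like argument serves as a template for the retrieved value, so only its shape, dtype and PyTree structure need to match the saved state.

import equinox as eqx
import jax
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
eqx.experimental.set_state(index, jnp.arange(3.))

@jax.jit
def read():
    # A zeros placeholder with matching shape and dtype works as the like argument.
    return eqx.experimental.get_state(index, like=jnp.zeros(3))

print(read())  # [0. 1. 2.]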

Warning

Using get_state means that your operation will no longer be a pure function.


equinox.experimental.set_state(index: StateIndex, state: PyTree[Array]) -> None ¤

Save a PyTree of JAX arrays as a side-effect.

Arguments:

  • index: A key under which to save the state. Should be an instance of equinox.experimental.StateIndex.
  • state: A PyTree of JAX arrays to save.

Returns:

None.

Raises:

A RuntimeError at run time if this index has previously been used to save a state with a different shape, dtype, PyTree structure, or batch axes.

A RuntimeError at trace time if index.inference is truthy.

A TypeError at trace time if state is not a PyTree of JAX arrays.

A NotImplementedError at trace time if trying to compute a gradient through state.

Info

The same index can be used multiple times, to overwrite a previously saved value. The new and old state must both have the same PyTree structure, however.
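
For instance, a minimal sketch of overwriting a previously saved value (same shape, dtype and PyTree structure each time):

import equinox as eqx
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
eqx.experimental.set_state(index, jnp.array(1.))
eqx.experimental.set_state(index, jnp.array(2.))  # overwrites the previous value

print(eqx.experimental.get_state(index, like=jnp.array(0.)))  # 2.0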

Warning

Note that state cannot be differentiated.

Warning

Using set_state means that your operation will no longer be a pure function. Moreover, note that the saving-as-a-side-effect may occur even when set_state is wrapped in lax.cond etc. (For example, under vmap, lax.cond is transformed into lax.select, so both branches are evaluated.)