equinox.nn.inference_mode(pytree: PyTree, value: bool = True) -> PyTree

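Convenience function for setting every inference attribute in a PyTree.

Inference flags toggle the behaviour of several of the pre-built neural network layers, such as equinox.nn.Dropout and equinox.nn.BatchNorm. With inference=True, Dropout is not applied, and BatchNorm normalises with its running statistics instead of computing (and updating them from) batch statistics.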
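To make the toggle concrete, here is a minimal, stdlib-only sketch of what an inference flag switches in a dropout-style layer. `ToyDropout` is a hypothetical class, not `equinox.nn.Dropout`; it assumes the common "inverted dropout" scheme (during training, zero each element with probability `p` and rescale the survivors by `1/(1 - p)`) and simply returns its input in inference mode.

```python
import dataclasses
import random


@dataclasses.dataclass(frozen=True)
class ToyDropout:
    """Hypothetical stand-in for a dropout layer; not Equinox's implementation."""

    p: float = 0.5
    inference: bool = False

    def __call__(self, xs, rng=None):
        # Inference mode (or p == 0): dropout is the identity.
        if self.inference or self.p == 0:
            return list(xs)
        if rng is None:
            raise ValueError("training-mode dropout needs a random source")
        q = 1.0 - self.p
        # Inverted dropout: keep each element with probability q and scale the
        # survivors by 1/q, so each output's expected value equals its input.
        return [x / q if rng.random() < q else 0.0 for x in xs]


train_layer = ToyDropout(p=0.4)
eval_layer = dataclasses.replace(train_layer, inference=True)  # copy, flag flipped

xs = [1.0, 2.0, 3.0, 4.0]
print(train_layer(xs, random.Random(0)))  # each entry is either 0.0 or x / 0.6
print(eval_layer(xs))                     # [1.0, 2.0, 3.0, 4.0]
```

Because `ToyDropout` is frozen, `dataclasses.replace` builds a modified copy instead of mutating the layer, which matches the copy-not-mutate behaviour of inference_mode.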


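import equinox as eqx
import jax


class Model(eqx.Module):
    norm: eqx.nn.BatchNorm
    dropout: eqx.nn.Dropout
    linear: eqx.nn.Linear

    def __init__(self, key):
        key1, key2 = jax.random.split(key)
        # BatchNorm over 3 features, computing statistics across the vmap axis named "batch".
        self.norm = eqx.nn.BatchNorm(3, "batch", key=key1)
        self.dropout = eqx.nn.Dropout(0.4)  # fraction of entries to zero: p = 0.4
        self.linear = eqx.nn.Linear(3, 1, key=key2)

    def __call__(self, x, ctx, *, key):
        # ctx is threaded through BatchNorm, which returns an updated value.
        x, ctx = self.norm(x, ctx)
        x = self.dropout(x, key=key)
        x = self.linear(x)
        return x, ctx


training_model = Model(jax.random.PRNGKey(0))
# A copy with every inference flag set to True; training_model itself is unchanged.
inference_model = eqx.nn.inference_mode(training_model)
# value=False switches the flags back, giving a training-mode copy again.
training_model_again = eqx.nn.inference_mode(inference_model, value=False)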
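The batch-norm half of the toggle can be sketched the same way. In the stdlib-only toy below, `ToyBatchNorm` is a hypothetical single-feature layer that stores its running statistics on itself and updates them with an exponential moving average controlled by an assumed `momentum` field. Equinox's real BatchNorm differs in detail: it computes statistics across a named vmap axis and threads its state through the separate `ctx` value in the example above. The point illustrated is only the switch: training mode normalises with the batch's own statistics and returns an updated layer, while inference mode uses the stored statistics and leaves them untouched.

```python
import dataclasses
import statistics


@dataclasses.dataclass(frozen=True)
class ToyBatchNorm:
    """Hypothetical 1-D batch-norm-like layer; not Equinox's BatchNorm."""

    running_mean: float = 0.0
    running_var: float = 1.0
    momentum: float = 0.99  # assumed EMA coefficient
    eps: float = 1e-5
    inference: bool = False

    def __call__(self, batch):
        """Return (normalised batch, layer with possibly updated running stats)."""
        if self.inference:
            # Inference: normalise with the stored running statistics; no update.
            mean, var, new_layer = self.running_mean, self.running_var, self
        else:
            # Training: normalise with this batch's statistics and fold them
            # into the running averages (exponential moving average).
            mean = statistics.fmean(batch)
            var = statistics.pvariance(batch, mu=mean)
            m = self.momentum
            new_layer = dataclasses.replace(
                self,
                running_mean=m * self.running_mean + (1 - m) * mean,
                running_var=m * self.running_var + (1 - m) * var,
            )
        out = [(x - mean) / (var + self.eps) ** 0.5 for x in batch]
        return out, new_layer


layer = ToyBatchNorm()
out, trained = layer([1.0, 2.0, 3.0])                  # batch stats; running stats move
frozen = dataclasses.replace(trained, inference=True)  # hypothetical "inference copy"
out2, after = frozen([1.0, 2.0, 3.0])                  # running stats; `after is frozen`
```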

This function is essentially equivalent to:

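import equinox
import jax.tree_util as jtu

has_inference = lambda leaf: hasattr(leaf, "inference")

def inference_mode(pytree, value=True):
    # Treat nodes with an inference attribute as leaves, so flattening stops
    # there, and select that attribute from each of them.
    def where(tree):
        return tuple(x.inference
                     for x in jtu.tree_leaves(tree, is_leaf=has_inference)
                     if has_inference(x))

    # Build a new PyTree with every selected attribute replaced by value.
    return equinox.tree_at(where, pytree, replace_fn=lambda _: value)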
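For readers without Equinox at hand, the same walk can be sketched with only the standard library. This is a hypothetical analogue, not how Equinox implements it: `set_inference`, `ToyLayer` and `ToyBlock` are illustrative names, and frozen dataclasses, tuples, lists and dicts stand in for PyTree nodes. Like `is_leaf=has_inference` above, it stops at the first node carrying an `inference` attribute and does not look inside it.

```python
import dataclasses


def _is_dataclass_instance(x):
    return dataclasses.is_dataclass(x) and not isinstance(x, type)


def set_inference(tree, value=True):
    """Hypothetical stdlib analogue of the tree_at snippet; not Equinox's code."""
    # An inference-bearing node is treated as a leaf: set its flag and stop,
    # mirroring is_leaf=has_inference (its children are not visited).
    if _is_dataclass_instance(tree) and hasattr(tree, "inference"):
        return dataclasses.replace(tree, inference=value)
    if type(tree) in (tuple, list):
        return type(tree)(set_inference(t, value) for t in tree)
    if isinstance(tree, dict):
        return {k: set_inference(v, value) for k, v in tree.items()}
    if _is_dataclass_instance(tree):
        # Rebuild any other dataclass field by field (init fields only).
        changes = {f.name: set_inference(getattr(tree, f.name), value)
                   for f in dataclasses.fields(tree) if f.init}
        return dataclasses.replace(tree, **changes)
    return tree  # any other leaf (numbers, strings, ...) is returned as-is


@dataclasses.dataclass(frozen=True)
class ToyLayer:  # hypothetical layer carrying an inference flag
    p: float
    inference: bool = False


@dataclasses.dataclass(frozen=True)
class ToyBlock:  # hypothetical container with no flag of its own
    layers: tuple
    name: str = "block"


model = ToyBlock(layers=(ToyLayer(0.4), {"head": ToyLayer(0.1)}))
eval_model = set_inference(model)                     # every flag -> True
train_again = set_inference(eval_model, value=False)  # every flag -> False
print(eval_model.layers[0].inference, model.layers[0].inference)  # True False
print(train_again == model)                                       # True
```

Returning new objects rather than mutating in place keeps the original tree usable, as with the training_model / inference_model pair earlier.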


Arguments:

  • pytree: the PyTree to modify.
  • value: the value to set every inference attribute to. Defaults to True (inference mode); pass False to switch back to training mode.

Returns:

A copy of pytree with all inference flags set to value. The input pytree is not modified.