Training/Inference
equinox.nn.inference_mode(pytree: PyTree, value: bool = True) -> PyTree
Convenience function for setting all inference attributes.

inference flags are used to toggle the behaviour of a number of the pre-built neural network layers, such as equinox.nn.Dropout or equinox.nn.BatchNorm.
Example
import equinox as eqx
import jax

class Model(eqx.Module):
    norm: eqx.nn.BatchNorm
    dropout: eqx.nn.Dropout
    linear: eqx.nn.Linear

    def __init__(self, key):
        key1, key2 = jax.random.split(key)
        self.norm = eqx.nn.BatchNorm(3, "batch", key=key1)
        self.dropout = eqx.nn.Dropout(0.4)
        self.linear = eqx.nn.Linear(3, 1, key=key2)

    def __call__(self, x, ctx, *, key):
        x, ctx = self.norm(x, ctx)
        x = self.dropout(x, key=key)
        x = self.linear(x)
        return x, ctx

training_model = Model(jax.random.PRNGKey(0))
inference_model = eqx.nn.inference_mode(training_model)
training_model_again = eqx.nn.inference_mode(inference_model, value=False)
This function is essentially equivalent to:
import equinox
import jax.tree_util as jtu

# `pytree` and `value` are the arguments of `inference_mode`.
has_inference = lambda leaf: hasattr(leaf, "inference")

def where(pytree):
    return tuple(x.inference
                 for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
                 if has_inference(x))

inference_pytree = equinox.tree_at(where, pytree, replace_fn=lambda _: value)
Arguments:
pytree: the PyTree to modify.
value: the value to set all inference attributes to. Defaults to True, i.e. inference mode.
Returns:
A copy of pytree with all inference flags set to value.