Sequential
"Sequential" is a common pattern in neural network frameworks: a sequence of layers applied in order.
It is useful for building fairly straightforward models. For anything nontrivial, subclass equinox.Module instead.
equinox.nn.Sequential (StatefulLayer)
A sequence of equinox.Modules applied in order.
Note
Activation functions can be added by wrapping them in equinox.nn.Lambda.
equinox.nn.Lambda (Module)
Wraps a callable (e.g. an activation function) for use with equinox.nn.Sequential.
Precisely, this just adds an extra key argument (that is ignored). Given some function fn, Lambda is essentially a convenience for lambda x, key: fn(x).
FAQ
If you get a TypeError saying the function is not a valid JAX type, see the FAQ.
Example
model = eqx.nn.Sequential(
    [
        eqx.nn.Linear(...),
        eqx.nn.Lambda(jax.nn.relu),
        ...
    ]
)
equinox.nn.StatefulLayer (Module)
An abstract base class, used by equinox.nn.Sequential, to mark that a layer might be stateful. If Sequential sees that a layer inherits from StatefulLayer, then it will call layer.is_stateful() to check whether to call the layer as new_x = layer(x) or (new_x, new_state) = layer(x, state).
is_stateful(self) -> bool
Indicates whether this layer should be considered stateful.
The default implementation just returns True, but subclasses may override this to provide custom logic if the layer is only "maybe stateful". (E.g. if they optionally use stateful sublayers themselves.)
Arguments:
None
Returns:
A boolean. True indicates that the layer should be called as (new_x, new_state) = layer(x, state). False indicates that the layer should be called as new_x = layer(x).