Sequential
"Sequential" is a common pattern in neural network frameworks, indicating a sequence of layers applied in order.
These are useful when building fairly straightforward models. For anything nontrivial, subclass equinox.Module instead.
equinox.nn.Sequential (StatefulLayer)
A sequence of equinox.Module instances applied in order.
Note
Activation functions can be added by wrapping them in equinox.nn.Lambda.
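The ordering rule can be illustrated without Equinox at all. The sketch below is a simplified, hypothetical stand-in (apply_in_order is not part of Equinox) that threads an input through plain Python callables from first to last, so swapping two layers changes the result.

```python
from typing import Any, Callable, Sequence


def apply_in_order(layers: Sequence[Callable[[Any], Any]], x: Any) -> Any:
    """Hypothetical helper: feed x through each layer, first to last."""
    for layer in layers:
        x = layer(x)  # the output of one layer becomes the input of the next
    return x


def double(v: float) -> float:
    return 2 * v


def add_one(v: float) -> float:
    return v + 1


print(apply_in_order([double, add_one], 3))  # 7, i.e. (3 * 2) + 1
print(apply_in_order([add_one, double], 3))  # 8, i.e. (3 + 1) * 2
```

The real equinox.nn.Sequential additionally handles the key and state arguments documented for __call__.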
__init__(self, layers: Sequence[Callable])
Arguments:
layers: A sequence of equinox.Module instances.
__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]
Arguments:
x: Passed to the first member of the sequence.
state: If provided, then it is passed to, and updated from, any layer which subclasses equinox.nn.StatefulLayer.
key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns: The output of the last member of the sequence.
If state is passed, then a 2-tuple of (output, state) is returned.
If state is not passed, then just the output is returned.
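A minimal sketch of this return convention, assuming a private sentinel object and a hypothetical `stateful` attribute for marking layers. Equinox itself uses StatefulLayer.is_stateful() for that check; run and CountCalls are illustrative names, not Equinox API. Using a sentinel rather than None lets the sketch tell "no state passed" apart from any real state value.

```python
from typing import Any, Callable, Sequence

_SENTINEL = object()  # hypothetical marker meaning "no state was passed"


class CountCalls:
    """Hypothetical stateful layer: returns x unchanged and bumps a counter in the state."""

    stateful = True  # hypothetical marker; not how Equinox flags stateful layers

    def __call__(self, x: Any, state: dict) -> tuple[Any, dict]:
        new_state = {**state, "calls": state["calls"] + 1}
        return x, new_state


def run(layers: Sequence[Callable], x: Any, state: Any = _SENTINEL):
    """Simplified sketch of the output-vs-(output, state) return rule."""
    for layer in layers:
        if getattr(layer, "stateful", False):
            # Assumes a real state was supplied; this sketch does not guard against that.
            x, state = layer(x, state)
        else:
            x = layer(x)
    if state is _SENTINEL:
        return x  # no state passed: plain output
    return x, state  # state passed: 2-tuple


print(run([abs], -4))  # 4
print(run([abs, CountCalls()], -4, {"calls": 0}))  # (4, {'calls': 1})
```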
equinox.nn.Lambda (Module)
Wraps a callable (e.g. an activation function) for use with equinox.nn.Sequential.
Precisely, this just adds an extra key argument (that is ignored). Given some function fn, Lambda is essentially a convenience for lambda x, key: fn(x).
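The extra argument lets a container call every layer in the same way, passing key whether or not the layer uses it. The pure-Python sketch below mimics that behavior. MiniLambda is a hypothetical illustration rather than Equinox's implementation (which is an equinox.Module), and a plain Python relu stands in for jax.nn.relu.

```python
from typing import Any, Callable, Optional


class MiniLambda:
    """Hypothetical stand-in for eqx.nn.Lambda: wraps fn and ignores key."""

    def __init__(self, fn: Callable[[Any], Any]):
        self.fn = fn

    def __call__(self, x: Any, *, key: Optional[Any] = None) -> Any:
        del key  # accepted only so every layer can be called the same way
        return self.fn(x)


def relu(v: float) -> float:
    return max(v, 0.0)


wrapped = MiniLambda(relu)
print(wrapped(-3.0, key="unused"))  # 0.0

try:
    relu(-3.0, key="unused")  # a bare function rejects the extra keyword
except TypeError as err:
    print("TypeError:", err)
```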
FAQ
If you get a TypeError saying the function is not a valid JAX type, see the FAQ.
Example
import equinox as eqx
import jax

model = eqx.nn.Sequential(
    [
        eqx.nn.Linear(...),
        eqx.nn.Lambda(jax.nn.relu),
        ...
    ]
)
__init__(self, fn: Callable[[Any], Any])
Arguments:
fn: A callable to be wrapped in equinox.Module.
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array
Arguments:
x: The input JAX array.
key: Ignored.
Returns:
The output of the fn(x) operation.
equinox.nn.StatefulLayer (Module)
An abstract base class, used by equinox.nn.Sequential, to mark that a layer might be stateful. If Sequential sees that a layer inherits from StatefulLayer, then it will call layer.is_stateful() to check whether to call the layer as new_x = layer(x) or (new_x, new_state) = layer(x, state).
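The dispatch rule above can be written out as a small helper. This is a pure-Python sketch, not Equinox code: StatefulBase, RunningSum and call_layer are hypothetical names, and StatefulBase only imitates the marker role of StatefulLayer.

```python
from typing import Any


class StatefulBase:
    """Hypothetical stand-in for eqx.nn.StatefulLayer."""

    def is_stateful(self) -> bool:
        return True  # default: treat the layer as stateful


class RunningSum(StatefulBase):
    """Hypothetical stateful layer: adds x to a running total kept in the state."""

    def __call__(self, x: float, state: float) -> tuple[float, float]:
        new_state = state + x
        return x, new_state


def call_layer(layer: Any, x: Any, state: Any) -> tuple[Any, Any]:
    """Call a layer one way or the other depending on the marker class and is_stateful()."""
    if isinstance(layer, StatefulBase) and layer.is_stateful():
        return layer(x, state)  # (new_x, new_state) = layer(x, state)
    return layer(x), state  # new_x = layer(x); state passes through unchanged


x, state = 2.0, 10.0
for layer in [abs, RunningSum(), lambda v: v * 3]:
    x, state = call_layer(layer, x, state)
print(x, state)  # 6.0 12.0
```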
is_stateful(self) -> bool
Indicates whether this layer should be considered stateful.
The default implementation just returns True, but subclasses may override this to provide custom logic if the layer is only "maybe stateful". (E.g. if they optionally use stateful sublayers themselves.)
Arguments:
None
Returns:
A boolean. True indicates that the layer should be called as (new_x, new_state) = layer(x, state). False indicates that the layer should be called as new_x = layer(x).
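As a sketch of the "maybe stateful" case, the hypothetical wrapper below overrides is_stateful() so that it reports True only when its inner layer is itself stateful. All class names are illustrative, and StatefulBase merely imitates the default-True behavior described above; this is not Equinox's own code.

```python
from typing import Any, Callable, Optional


class StatefulBase:
    """Hypothetical stand-in for eqx.nn.StatefulLayer."""

    def is_stateful(self) -> bool:
        return True


class Counter(StatefulBase):
    """Hypothetical always-stateful layer: counts how many times it has run."""

    def __call__(self, x: Any, state: int) -> tuple[Any, int]:
        return x, state + 1


class MaybeStateful(StatefulBase):
    """Hypothetical wrapper that is stateful only if its inner layer is."""

    def __init__(self, inner: Callable):
        self.inner = inner

    def is_stateful(self) -> bool:
        # Override the default True: defer to the wrapped sublayer.
        return isinstance(self.inner, StatefulBase) and self.inner.is_stateful()

    def __call__(self, x: Any, state: Optional[Any] = None):
        if self.is_stateful():
            return self.inner(x, state)  # (new_x, new_state)
        return self.inner(x)  # new_x only


print(MaybeStateful(abs).is_stateful())  # False
print(MaybeStateful(Counter()).is_stateful())  # True
```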