
Sequential

"Sequential" is a common pattern in neural network frameworks, indicating a sequence of layers applied in order.

These are useful when building fairly straightforward models. But for anything nontrivial, subclass equinox.Module instead.


equinox.nn.Sequential (StatefulLayer)

A sequence of equinox.Modules applied in order.

Note

Activation functions can be added by wrapping them in equinox.nn.Lambda.

__init__(self, layers: Sequence[Callable])

Arguments:

  • layers: A sequence of equinox.Modules, applied in order.

__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]

Arguments:

  • x: passed to the first member of the sequence.
  • state: If provided, then it is passed to, and updated from, any layer which subclasses equinox.nn.StatefulLayer and whose is_stateful() returns True.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns: The output of the last member of the sequence.

If state is passed, then a 2-tuple of (output, state) is returned. If state is not passed, then just the output is returned.
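To make the two return conventions concrete, here is a simplified pure-Python model of the dispatch rule described above. It is a sketch, not Equinox's source: MiniSequential, MiniStatefulLayer, CallCounter, double and the _SENTINEL marker are hypothetical names, and a plain dict stands in for Equinox's State object.

```python
from typing import Any, Callable, Sequence

_SENTINEL = object()  # hypothetical marker meaning "no state was passed"


class MiniStatefulLayer:
    """Hypothetical stand-in for eqx.nn.StatefulLayer."""

    def is_stateful(self) -> bool:
        return True  # default behaviour, as documented for StatefulLayer


class CallCounter(MiniStatefulLayer):
    """Hypothetical stateful layer: passes x through and counts calls in a dict."""

    def __call__(self, x, state):
        return x, {**state, "calls": state.get("calls", 0) + 1}


class MiniSequential:
    """Hypothetical model of Sequential's call logic (not the real implementation)."""

    def __init__(self, layers: Sequence[Callable]):
        self.layers = tuple(layers)

    def __call__(self, x: Any, state: Any = _SENTINEL, *, key=None):
        del key  # ignored, matching the documented behaviour
        for layer in self.layers:
            if isinstance(layer, MiniStatefulLayer) and layer.is_stateful():
                # Thread the state through layers that declare themselves stateful.
                x, state = layer(x, state)
            else:
                x = layer(x)
        # Return a 2-tuple only if the caller supplied a state.
        if state is _SENTINEL:
            return x
        return x, state


def double(x):
    return 2 * x


print(MiniSequential([double, double])(3))  # 12
print(MiniSequential([double, CallCounter()])(3, {}))  # (6, {'calls': 1})
```

The design point the sketch isolates is that the return type depends only on whether the caller passed state, not on whether any layer actually used it.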


equinox.nn.Lambda (Module)

Wraps a callable (e.g. an activation function) for use with equinox.nn.Sequential.

Precisely, this just adds an extra key argument (which is ignored). Given some function fn, Lambda(fn) is essentially a convenience for lambda x, key: fn(x).
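A minimal pure-Python sketch of that idea follows. MiniLambda and relu are hypothetical stand-ins (the real Lambda is an equinox.Module); the only behaviour modelled is the one stated above: forward x to fn and discard key.

```python
from typing import Any, Callable, Optional


class MiniLambda:
    """Hypothetical stand-in for eqx.nn.Lambda: adds an ignored `key` argument."""

    def __init__(self, fn: Callable[[Any], Any]):
        self.fn = fn

    def __call__(self, x: Any, *, key: Optional[Any] = None) -> Any:
        del key  # accepted only so the wrapper matches the layer call signature
        return self.fn(x)


def relu(v: float) -> float:
    return max(v, 0.0)


wrapped = MiniLambda(relu)
print(wrapped(-2.0, key="ignored"))  # 0.0
print(wrapped(3.0))  # 3.0
```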

FAQ

If you get a TypeError saying the function is not a valid JAX type, see the FAQ.

Example

   model = eqx.nn.Sequential(
       [
           eqx.nn.Linear(...),
           eqx.nn.Lambda(jax.nn.relu),
           ...
       ]
   )

__init__(self, fn: Callable[[Any], Any])

Arguments:

  • fn: The callable to wrap, e.g. jax.nn.relu.

__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

  • x: The input JAX array.
  • key: Ignored.

Returns:

The output of the fn(x) operation.


equinox.nn.StatefulLayer (Module)

An abstract base class, used by equinox.nn.Sequential, to mark that a layer might be stateful. If Sequential sees that a layer inherits from StatefulLayer, then it will call layer.is_stateful() to check whether to call the layer as new_x = layer(x) or (new_x, new_state) = layer(x, state).

is_stateful(self) -> bool

Indicates whether this layer should be considered stateful.

The default implementation just returns True, but subclasses may override this to provide custom logic if the layer is only "maybe stateful". (E.g. if they optionally use stateful sublayers themselves.)

Arguments:

None

Returns:

A boolean. True indicates that the layer should be called as (new_x, new_state) = layer(x, state). False indicates that the layer should be called as new_x = layer(x).
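The "maybe stateful" case can be modelled in plain Python. Everything here is hypothetical (MiniStatefulLayer, StepCounter, Times10, MaybeStateful and dispatch are not Equinox names): the wrapper overrides is_stateful() to report True only when its sublayer is stateful, and the dispatcher chooses the calling convention from that answer, mirroring the rule described for Sequential.

```python
class MiniStatefulLayer:
    """Hypothetical stand-in for eqx.nn.StatefulLayer."""

    def is_stateful(self) -> bool:
        return True  # default, as documented


class StepCounter(MiniStatefulLayer):
    """Hypothetical stateful layer: adds 1 to x and to an integer state."""

    def __call__(self, x, state):
        return x + 1, state + 1


class Times10:
    """Hypothetical plain layer: not a MiniStatefulLayer, so called as layer(x)."""

    def __call__(self, x):
        return x * 10


class MaybeStateful(MiniStatefulLayer):
    """Hypothetical wrapper that is stateful only if its sublayer is."""

    def __init__(self, sublayer):
        self.sublayer = sublayer

    def is_stateful(self) -> bool:
        return isinstance(self.sublayer, MiniStatefulLayer) and self.sublayer.is_stateful()

    def __call__(self, x, state=None):
        if self.is_stateful():
            return self.sublayer(x, state)
        return self.sublayer(x)


def dispatch(layer, x, state):
    """Picks the calling convention the way a Sequential-style container would."""
    if isinstance(layer, MiniStatefulLayer) and layer.is_stateful():
        x, state = layer(x, state)  # (new_x, new_state) = layer(x, state)
    else:
        x = layer(x)  # new_x = layer(x)
    return x, state


print(dispatch(MaybeStateful(StepCounter()), 1, 0))  # (2, 1)
print(dispatch(MaybeStateful(Times10()), 1, 0))  # (10, 0)
```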