Composites¤

equinox.nn.MLP (Module) ¤

Standard Multi-Layer Perceptron; also known as a feed-forward network.

__init__(self, in_size: int, out_size: int, width_size: int, depth: int, activation: Callable = <function relu>, final_activation: Callable = <function _identity>, *, key: jax.random.PRNGKey, **kwargs) ¤

Arguments:

  • in_size: The size of the input layer.
  • out_size: The size of the output layer.
  • width_size: The size of each hidden layer.
  • depth: The number of hidden layers.
  • activation: The activation function after each hidden layer. Defaults to ReLU.
  • final_activation: The activation function after the output layer. Defaults to the identity.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array ¤

Arguments:

  • x: A JAX array with shape (in_size,).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array with shape (out_size,).


equinox.nn.Sequential (Module) ¤

A sequence of equinox.Modules applied in order.

Note

Activation functions can be added by wrapping them in equinox.nn.Lambda.

__init__(self, layers: Sequence[Module]) ¤

Arguments:

  • layers: A sequence of equinox.Modules to be applied in order.
__call__(self, x: Any, *, key: Optional[jax.random.PRNGKey] = None) -> Any ¤

Arguments:

  • x: Argument passed to the first member of the sequence.
  • key: A jax.random.PRNGKey, which will be split and passed to every layer to provide any desired randomness. (Optional. Keyword only argument.)

Returns:

The output of the last member of the sequence.


equinox.nn.Lambda (Module) ¤

Wraps a callable (e.g. an activation function) for use with equinox.nn.Sequential.

Precisely, this just adds an extra key argument (that is ignored). Given some function fn, Lambda is essentially a convenience for lambda x, key: fn(x).

Example

   model = eqx.nn.Sequential(
       [
           eqx.nn.Linear(...),
           eqx.nn.Lambda(jax.nn.relu),
           ...
       ]
   )
__init__(self, fn: Callable) ¤

Arguments:

  • fn: A callable to wrap; it will be called as fn(x).
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array ¤

Arguments:

  • x: The input JAX array.
  • key: Ignored.

Returns:

The output of fn(x).