
Multi-layer perceptron

equinox.nn.MLP (Module)

Standard Multi-Layer Perceptron; also known as a feed-forward network.


If you get a TypeError saying an object is not a valid JAX type, see the FAQ.

__init__(self, in_size: Union[int, Literal['scalar']], out_size: Union[int, Literal['scalar']], width_size: int, depth: int, activation: Callable = <function relu>, final_activation: Callable = lambda x: x, use_bias: bool = True, use_final_bias: bool = True, dtype = None, *, key: PRNGKeyArray)


  • in_size: The input size. The input to the module should be a vector of shape (in_size,).
  • out_size: The output size. The output from the module will be a vector of shape (out_size,).
  • width_size: The size of each hidden layer.
  • depth: The number of hidden layers, including the output layer. For example, depth=2 results in a network with three linear layers: [Linear(in_size, width_size), Linear(width_size, width_size), Linear(width_size, out_size)].
  • activation: The activation function after each hidden layer. Defaults to ReLU.
  • final_activation: The activation function after the output layer. Defaults to the identity.
  • use_bias: Whether to add on a bias to internal layers. Defaults to True.
  • use_final_bias: Whether to add on a bias to the final layer. Defaults to True.
  • dtype: The dtype to use for all the weights and biases in this MLP. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)

Note that in_size also supports the string "scalar" as a special value. In this case the input to the module should be of shape ().

Likewise, out_size can also be the string "scalar", in which case the output from the module will have shape ().

__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array


  • x: A JAX array with shape (in_size,). (Or shape () if in_size="scalar".)
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword-only argument.)


Returns a JAX array with shape (out_size,), or shape () if out_size="scalar".