Multi-layer perceptron
equinox.nn.MLP(equinox.Module)
Standard Multi-Layer Perceptron; also known as a feed-forward network.
FAQ
If you get a TypeError saying an object is not a valid JAX type, see the FAQ.
__init__(in_size: int | Literal['scalar'], out_size: int | Literal['scalar'], width_size: int, depth: int, activation: Callable = <function relu>, final_activation: Callable = lambda x: x, use_bias: bool = True, use_final_bias: bool = True, dtype=None, *, key: PRNGKeyArray)
Arguments:
in_size: The input size. The input to the module should be a vector of shape (in_size,).
out_size: The output size. The output from the module will be a vector of shape (out_size,).
width_size: The size of each hidden layer.
depth: The number of hidden layers. The network contains depth + 1 linear layers in total; for example, depth=2 results in a network with layers [Linear(in_size, width_size), Linear(width_size, width_size), Linear(width_size, out_size)].
activation: The activation function after each hidden layer. Defaults to ReLU.
final_activation: The activation function after the output layer. Defaults to the identity.
use_bias: Whether to add a bias to the internal layers. Defaults to True.
use_final_bias: Whether to add a bias to the final layer. Defaults to True.
dtype: The dtype to use for all the weights and biases in this MLP. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
Note that in_size also supports the string "scalar" as a special value.
In this case the input to the module should be of shape ().
Likewise out_size can also be a string "scalar", in which case the
output from the module will have shape ().
__call__(x: Array, *, key: PRNGKeyArray | None = None) -> Array
Arguments:
x: A JAX array with shape (in_size,). (Or shape () if in_size="scalar".)
key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array with shape (out_size,). (Or shape () if out_size="scalar".)