Multi-layer perceptron
equinox.nn.MLP (Module)
Standard Multi-Layer Perceptron; also known as a feed-forward network.
FAQ
If you get a TypeError saying an object is not a valid JAX type, see the FAQ.
__init__(self, in_size: Union[int, Literal['scalar']], out_size: Union[int, Literal['scalar']], width_size: int, depth: int, activation: Callable = jax.nn.relu, final_activation: Callable = lambda x: x, use_bias: bool = True, use_final_bias: bool = True, dtype = None, *, key: PRNGKeyArray)
Arguments:
in_size: The input size. The input to the module should be a vector of shape (in_size,).
out_size: The output size. The output from the module will be a vector of shape (out_size,).
width_size: The size of each hidden layer.
depth: The number of hidden layers, each of size width_size. The network contains depth + 1 linear layers in total; for example, depth=2 results in a network with layers [Linear(in_size, width_size), Linear(width_size, width_size), Linear(width_size, out_size)].
activation: The activation function applied after each hidden layer. Defaults to ReLU.
final_activation: The activation function applied after the output layer. Defaults to the identity.
use_bias: Whether to add a bias to the internal layers. Defaults to True.
use_final_bias: Whether to add a bias to the final layer. Defaults to True.
dtype: The dtype to use for all the weights and biases in this MLP. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
Note that in_size also supports the string "scalar" as a special value, in which case the input to the module should be of shape (). Likewise, out_size can also be the string "scalar", in which case the output from the module will have shape ().
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array
Arguments:
x: A JAX array with shape (in_size,). (Or shape () if in_size="scalar".)
key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword-only argument.)
Returns:
A JAX array with shape (out_size,). (Or shape () if out_size="scalar".)