Standard Multi-Layer Perceptron; also known as a feed-forward network.
If you get a TypeError saying an object is not a valid JAX type, see the FAQ.
__init__(self, in_size: Union[int, Literal['scalar']], out_size: Union[int, Literal['scalar']], width_size: int, depth: int, activation: Callable = jax.nn.relu, final_activation: Callable = lambda x: x, use_bias: bool = True, use_final_bias: bool = True, *, key: PRNGKeyArray, **kwargs)
in_size: The input size. The input to the module should be a vector of shape (in_size,).
out_size: The output size. The output from the module will be a vector of shape (out_size,).
width_size: The size of each hidden layer.
depth: The number of hidden layers, including the output layer. For example, depth=2 results in a network with layers: [Linear(in_size, width_size), Linear(width_size, width_size), Linear(width_size, out_size)].
activation: The activation function after each hidden layer. Defaults to ReLU.
final_activation: The activation function after the output layer. Defaults to the identity.
use_bias: Whether to add on a bias to internal layers. Defaults to True.
use_final_bias: Whether to add on a bias to the final layer. Defaults to True.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
Note that in_size also supports the string "scalar" as a special value. In this case the input to the module should be of shape (). Likewise out_size can also be a string "scalar", in which case the output from the module will have shape ().
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array
x: A JAX array with shape (in_size,). (Or shape () if in_size="scalar".)
key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
A JAX array with shape (out_size,). (Or shape () if out_size="scalar".)
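The call semantics described above can be illustrated with a minimal NumPy stand-in (a sketch with hand-picked weights; Equinox's actual MLP stores its Linear layers as a PyTree and initialises them from key):

```python
import numpy as np

def mlp_forward(x, weights, biases):
    """Forward pass matching the defaults above: ReLU after each
    hidden layer, identity after the final layer."""
    for W, b in zip(weights[:-1], biases[:-1]):
        x = np.maximum(W @ x + b, 0.0)    # hidden layer + ReLU
    return weights[-1] @ x + biases[-1]   # final layer, identity activation

# depth=1, in_size=2, width_size=4, out_size=3 -> two Linear layers.
rng = np.random.default_rng(0)
weights = [rng.standard_normal((4, 2)), rng.standard_normal((3, 4))]
biases = [np.zeros(4), np.zeros(3)]
y = mlp_forward(np.ones(2), weights, biases)
# y has shape (out_size,), i.e. (3,)
```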