

In addition to this page, note that JAX also has many activation functions built-in, such as jax.nn.relu or jax.nn.softplus.

equinox.nn.PReLU (Module)

PReLU activation function.

This is the elementwise function \(x \mapsto \max(x, 0) + \alpha \min(x, 0)\). It can be thought of as a leaky ReLU whose leak \(\alpha\) is learnt rather than fixed.
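A minimal pure-JAX sketch of the formula above, written only to illustrate it. The helper `prelu` is hypothetical and is not necessarily how Equinox implements the module. It checks that a fixed \(\alpha\) reproduces `jax.nn.leaky_relu`, and that \(\alpha\) has a well-defined gradient, which is what makes it learnable.

```python
import jax
import jax.numpy as jnp


def prelu(x, alpha):
    # Hypothetical helper: elementwise max(x, 0) + alpha * min(x, 0).
    # Identity for x >= 0, slope `alpha` for x < 0.
    return jnp.maximum(x, 0) + alpha * jnp.minimum(x, 0)


x = jnp.array([-2.0, 0.0, 3.0])
alpha = 0.25

y = prelu(x, alpha)

# With alpha held fixed, this coincides with JAX's built-in leaky ReLU.
assert jnp.allclose(y, jax.nn.leaky_relu(x, negative_slope=alpha))

# alpha enters linearly, so d/d(alpha) sum(prelu(x, alpha)) = sum(min(x, 0)).
# This gradient is what lets the leak be learnt during training.
dalpha = jax.grad(lambda a: prelu(x, a).sum())(alpha)
```

Here `y` is `[-0.5, 0.0, 3.0]` and `dalpha` is `-2.0`: only the negative entry contributes to the gradient of the slope.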

__init__(self, init_alpha: Union[float, Array] = 0.25)


  • init_alpha: The initial value \(\alpha\) of the negative slope. This should be either a float (the default is \(0.25\)) or a JAX array of per-element values \(\alpha_i\), whose shape must be broadcastable to the shape of the input.
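When `init_alpha` is an array, its shape has to broadcast against the input under the usual NumPy/JAX rules. The sketch below assumes a channel-first input of shape `(3, 2, 2)` and one slope per channel; the shapes and the `prelu` helper are illustrative assumptions, not part of the Equinox API.

```python
import jax.numpy as jnp


def prelu(x, alpha):
    # Hypothetical helper: elementwise max(x, 0) + alpha * min(x, 0).
    return jnp.maximum(x, 0) + alpha * jnp.minimum(x, 0)


# Assumed channel-first input: 3 channels, each a 2x2 grid of -1.0.
x = -jnp.ones((3, 2, 2))

# One slope per channel. Shape (3, 1, 1) broadcasts against (3, 2, 2).
# A plain shape-(3,) array would be aligned with the trailing axis (size 2)
# and would fail to broadcast here.
alpha = jnp.array([0.1, 0.2, 0.3]).reshape(3, 1, 1)

y = prelu(x, alpha)
```

The output keeps the input's shape `(3, 2, 2)`, and every element of channel `i` equals `-alpha[i]`, since each input value is `-1.0`.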
__call__(self, x: Array) -> Array


  • x: The input.


Returns a JAX array of the same shape as the input.