
Activations

In addition to those on this page, note that JAX has many activation functions built in, such as jax.nn.relu or jax.nn.softplus.


equinox.nn.PReLU(equinox.Module)

PReLU activation function.

This is the elementwise function \(x \mapsto \max(x, 0) + \alpha \min(x, 0)\). It can be thought of as a leaky ReLU with a learnt leak \(\alpha\).
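For concreteness, here is a minimal sketch of that elementwise rule written directly with jax.numpy; the input values and the slope \(\alpha = 0.25\) are chosen purely for illustration.

```python
import jax.numpy as jnp

alpha = 0.25  # example negative slope
x = jnp.array([-2.0, -0.5, 0.0, 1.0, 3.0])

# The positive part passes through unchanged; the negative part is scaled by alpha.
out = jnp.maximum(x, 0) + alpha * jnp.minimum(x, 0)
# out is [-0.5, -0.125, 0.0, 1.0, 3.0]
```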

__init__(init_alpha: float | Array | None = 0.25)

Arguments:

  • init_alpha: The initial value \(\alpha\) of the negative slope. This should either be a float (the default is \(0.25\)) or a JAX array of \(\alpha_i\) values. If a JAX array is passed, its shape should be broadcastable against the input.
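For example, a per-channel leak can be set by passing a JAX array whose shape broadcasts against the input; this is a sketch, and the (channels, height, width) shapes used here are hypothetical.

```python
import jax.numpy as jnp
import equinox as eqx

# One learnt slope per channel, broadcast over the two spatial dimensions.
prelu = eqx.nn.PReLU(init_alpha=jnp.full((3, 1, 1), 0.25))

x = jnp.ones((3, 32, 32))  # hypothetical (channels, height, width) input
y = prelu(x)               # same shape as x
```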
__call__(x: Array) -> Array

Arguments:

  • x: The input.

Returns:

A JAX array of the same shape as the input.
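
Putting the pieces together, a minimal usage sketch with the default scalar \(\alpha\):

```python
import jax.numpy as jnp
import equinox as eqx

prelu = eqx.nn.PReLU()  # default init_alpha = 0.25
x = jnp.array([-2.0, -0.5, 0.0, 1.0, 3.0])

y = prelu(x)
# y matches max(x, 0) + 0.25 * min(x, 0), i.e. [-0.5, -0.125, 0.0, 1.0, 3.0],
# and has the same shape as x.
```

Because PReLU is an equinox.Module, the stored \(\alpha\) is a leaf of the module's pytree and is updated alongside the rest of a model's parameters during training.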