
Linear

equinox.nn.Linear (Module)

Performs a linear transformation.

__init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jax.random.PRNGKey)

Arguments:

  • in_features: The input size.
  • out_features: The output size.
  • use_bias: Whether to add a bias term as well.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_features,).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Info

If you want to use higher-order tensors as inputs (for example, ones with a batch dimension), use jax.vmap. For example, for an input x of shape (batch, in_features), using

linear = equinox.nn.Linear(...)
jax.vmap(linear)(x)
will produce the appropriate output of shape (batch, out_features).

Returns:

A JAX array of shape (out_features,)


equinox.nn.Identity (Module)

Identity operation that does nothing. Sometimes useful as a placeholder for another Module.

__init__(self, *args, **kwargs)

Accepts arbitrary *args and **kwargs and ignores them.

__call__(self, x: ~T, *, key: Optional[jax.random.PRNGKey] = None) -> ~T

Arguments:

  • x: The input, of any type.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

The input, unchanged.