
Linear

equinox.nn.Linear (Module)

Performs a linear transformation.

__init__(self, in_features: Union[int, Literal['scalar']], out_features: Union[int, Literal['scalar']], use_bias: bool = True, dtype = None, *, key: PRNGKeyArray)

Arguments:

  • in_features: The input size. The input to the layer should be a vector of shape (in_features,).
  • out_features: The output size. The output from the layer will be a vector of shape (out_features,).
  • use_bias: Whether to also add a bias.
  • dtype: The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

Note that in_features also supports the string "scalar" as a special value. In this case the input to the layer should be of shape ().

Likewise out_features can also be a string "scalar", in which case the output from the layer will have shape ().

__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_features,). (Or shape () if in_features="scalar".)
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Info

If you want to use higher-order tensors as inputs (for example, inputs featuring batch dimensions), then use jax.vmap. For example, for an input x of shape (batch, in_features), using

```python
import equinox
import jax
linear = equinox.nn.Linear(...)
jax.vmap(linear)(x)
```
will produce the appropriate output of shape (batch, out_features).

Returns:

A JAX array of shape (out_features,). (Or shape () if out_features="scalar".)


equinox.nn.Identity (Module)

Identity operation that does nothing. Sometimes useful as a placeholder for another Module.

__init__(self, *args: Any, **kwargs: Any)

Consumes arbitrary *args and **kwargs but ignores them.

__call__(self, x: ~_T, *, key: Optional[PRNGKeyArray] = None) -> ~_T

Arguments:

  • x: The input, of any type.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

The input, unchanged.