Linear
equinox.nn.Linear (Module)
Performs a linear transformation.
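In symbols, the layer computes the following. This is a sketch that assumes a weight matrix of shape (out_features, in_features) and a bias of shape (out_features,), which matches how Equinox stores them:

```latex
y = W x + b, \qquad
W \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}}, \quad
x \in \mathbb{R}^{d_{\mathrm{in}}}, \quad
b,\, y \in \mathbb{R}^{d_{\mathrm{out}}}
```

Here d_in = in_features and d_out = out_features. With use_bias=False the b term is dropped and y = Wx.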
__init__(self, in_features: int, out_features: int, use_bias: bool = True, *, key: jax.random.PRNGKey)
Arguments:
in_features
: The input size.
out_features
: The output size.
use_bias
: Whether to add on a bias as well.
key
: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (in_features,).
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Info
To use higher-order tensors as inputs (for example, inputs with batch dimensions), use jax.vmap. For example, for an input x of shape (batch, in_features),
linear = equinox.nn.Linear(...)
jax.vmap(linear)(x)
produces an output of shape (batch, out_features).
Returns:
A JAX array of shape (out_features,).
equinox.nn.Identity (Module)
Identity operation that does nothing. Sometimes useful as a placeholder for another Module.
__init__(self, *args, **kwargs)
Consumes arbitrary *args and **kwargs but ignores them.
__call__(self, x: ~T, *, key: Optional[jax.random.PRNGKey] = None) -> ~T
Arguments:
x
: The input, of any type.
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
The input, unchanged.