Linear
equinox.nn.Linear (Module)
Performs a linear transformation.
__init__(self, in_features: Union[int, Literal['scalar']], out_features: Union[int, Literal['scalar']], use_bias: bool = True, dtype = None, *, key: PRNGKeyArray)
Arguments:
in_features
: The input size. The input to the layer should be a vector of shape (in_features,).
out_features
: The output size. The output from the layer will be a vector of shape (out_features,).
use_bias
: Whether to add on a bias as well.
dtype
: The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
key
: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
Note that in_features also supports the string "scalar" as a special value. In this case the input to the layer should be of shape ().
Likewise out_features can also be the string "scalar", in which case the output from the layer will have shape ().
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array
Arguments:
x
: The input. Should be a JAX array of shape (in_features,). (Or shape () if in_features="scalar".)
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Info
If you want to use higher order tensors as inputs (for example featuring batch dimensions) then use jax.vmap. For example, for an input x of shape (batch, in_features), using
linear = equinox.nn.Linear(...)
jax.vmap(linear)(x)
will produce the appropriate output of shape (batch, out_features).
Returns:
A JAX array of shape (out_features,). (Or shape () if out_features="scalar".)
equinox.nn.Identity (Module)
Identity operation that does nothing. Sometimes useful as a placeholder for another Module.
__init__(self, *args: Any, **kwargs: Any)
Consumes arbitrary *args and **kwargs but ignores them.
__call__(self, x: ~_T, *, key: Optional[PRNGKeyArray] = None) -> ~_T
Arguments:
x
: The input, of any type.
key
: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
The input, unchanged.