Normalisation

equinox.nn.LayerNorm (Module)
Computes a mean and standard deviation over the whole input array, and uses these to normalise the whole array. Optionally applies an elementwise affine transformation afterwards.
Given an input array \(x\), this layer computes

\[\frac{x - \mathbb{E}[x]}{\sqrt{\operatorname{Var}[x] + \varepsilon}} * \gamma + \beta\]

where \(\gamma\), \(\beta\) have the same shape as \(x\) if `elementwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `elementwise_affine=False`.
Cite
@article{ba2016layer,
    author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
    title={Layer Normalization},
    year={2016},
    journal={arXiv:1607.06450},
}
FAQ
If you need to normalise over only some input dimensions, then this can be achieved by vmap'ing. For example the following will compute statistics over every dimension except the first:
layer = LayerNorm(...)
array = jax.vmap(layer)(array)
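For instance, a minimal runnable sketch of this pattern (the shapes below are purely illustrative):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Normalise over the trailing (32, 32) dimensions of a (batch, 32, 32) array,
# computing separate statistics for each element of the leading dimension.
layer = eqx.nn.LayerNorm(shape=(32, 32))
array = jnp.ones((8, 32, 32))
out = jax.vmap(layer)(array)  # out.shape == (8, 32, 32)
```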
__init__(self, shape: Union[int, Sequence[int]], eps: float = 1e-05, use_weight: bool = True, use_bias: bool = True, dtype = None, *, elementwise_affine: Optional[bool] = None)

Arguments:

- `shape`: Shape of the input.
- `eps`: Value added to the denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
- `use_bias`: Whether the module has learnable affine biases.
- `dtype`: The dtype to use for the weight and the bias in this layer if `use_weight` or `use_bias` is set to `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
- `elementwise_affine`: Deprecated alternative to `use_weight` and `use_bias`.
__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]

Arguments:

- `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
- `state`: Ignored; provided for interchangeability with the `equinox.nn.BatchNorm` API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:

The output is a JAX array of the same shape as `x`.

If `state` is passed, then a 2-tuple of `(output, state)` is returned. The state is passed through unchanged. If `state` is not passed, then just the output is returned.
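As a brief illustration of the call signature (the shape of 64 is arbitrary, chosen just for the example):

```python
import equinox as eqx
import jax.numpy as jnp

layer = eqx.nn.LayerNorm(shape=(64,))
x = jnp.arange(64.0)
y = layer(x)  # same shape as x; zero mean and unit variance over the whole array
```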
equinox.nn.RMSNorm (Module)
A simplified version of LayerNorm which rescales the inputs, but does not center them. Optionally applies a learned reweighting of the transformed array afterward.
Given an input array \(x\), this layer computes

\[\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta\]

where \(\Vert x \Vert^2_2 = \sum_{i=1}^n x_i^2\), \(n = \dim(x)\), and \(\gamma\) is a learned array with the same shape as \(x\) if `use_weight=True`, or \(\gamma = 1\) if `use_weight=False`, as proposed in this paper. \(\beta\) is an optional bias term.
Cite
Root Mean Square Layer Normalization
@article{zhang2019root,
    title={Root Mean Square Layer Normalization},
    author={Biao Zhang and Rico Sennrich},
    year={2019},
    journal={arXiv:1910.07467}
}
__init__(self, shape: Union[int, Sequence[int]], eps: float = 1e-05, use_weight: bool = True, use_bias: bool = True, dtype = None)

Arguments:

- `shape`: Shape of the input.
- `eps`: Value added to the denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
- `use_bias`: Whether the module has a learnable affine shift.
- `dtype`: The dtype to use for the weight and the bias in this layer if `use_weight` or `use_bias` is set to `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]

Arguments:

- `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
- `state`: Ignored; provided for interchangeability with the `equinox.nn.BatchNorm` API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:

The output is a JAX array of the same shape as `x`.

If `state` is passed, then a 2-tuple of `(output, state)` is returned. The state is passed through unchanged. If `state` is not passed, then just the output is returned.
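As a brief illustration (the shape and options below are chosen arbitrarily for the example):

```python
import equinox as eqx
import jax.numpy as jnp

norm = eqx.nn.RMSNorm(shape=(64,), use_bias=False)
x = jnp.arange(64.0)
y = norm(x)  # rescaled by the root-mean-square, but not centred; same shape as x
```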
equinox.nn.GroupNorm (Module)
Splits the first dimension ("channels") into groups of fixed size. Computes a mean and standard deviation over the contents of each group, and uses these to normalise the group. Optionally applies a channel-wise affine transformation afterwards.
Given an input array \(x\) of shape `(channels, ...)`, this layer splits it up into `groups`-many arrays \(x_i\), each of shape `(channels/groups, ...)`, and for each one computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\operatorname{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

where \(\gamma_i\), \(\beta_i\) have shape `(channels/groups,)` if `channelwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `channelwise_affine=False`.
Cite
@article{wu2018group,
    author={Yuxin Wu and Kaiming He},
    title={Group Normalization},
    year={2018},
    journal={arXiv:1803.08494},
}
__init__(self, groups: int, channels: Optional[int] = None, eps: float = 1e-05, channelwise_affine: bool = True, dtype = None)

Arguments:

- `groups`: The number of groups to split the input into.
- `channels`: The number of input channels. May be left unspecified (e.g. just `None`) if `channelwise_affine=False`.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable affine parameters.
- `dtype`: The dtype to use for the weight and the bias in this layer if `channelwise_affine` is set to `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]

Arguments:

- `x`: A JAX array of shape `(channels, ...)`.
- `state`: Ignored; provided for interchangeability with the `equinox.nn.BatchNorm` API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:

The output is a JAX array of shape `(channels, ...)`.

If `state` is passed, then a 2-tuple of `(output, state)` is returned. The state is passed through unchanged. If `state` is not passed, then just the output is returned.
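For example (the channel count, group count, and spatial shape below are illustrative):

```python
import equinox as eqx
import jax.numpy as jnp

# 64 channels split into 8 groups; statistics are computed within each group.
norm = eqx.nn.GroupNorm(groups=8, channels=64)
x = jnp.ones((64, 16, 16))  # (channels, ...), e.g. a single image's feature map
y = norm(x)  # same shape as x
```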
equinox.nn.BatchNorm (StatefulLayer)
Computes a mean and standard deviation over the batch and spatial dimensions of an array, and uses these to normalise the whole array. Optionally applies a channelwise affine transformation afterwards.
Given an input array \(x = [x_1, ... x_C]\) with \(C\) channels, this layer computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\operatorname{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

for all \(i\). Here \(*\) denotes elementwise multiplication and \(\gamma\), \(\beta\) have shape \((C,)\) if `channelwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `channelwise_affine=False`. Expectations are computed over all spatial dimensions and over the batch dimension, and updated batch-by-batch according to `momentum`.
Example
See this example for example usage.
Warning

This layer must be used inside of a `vmap` or `pmap` with a matching `axis_name`. (Not doing so will raise a `NameError`.)

Note that this layer behaves differently during training and inference. During training, statistics are computed using the input data and the running statistics are updated. During inference, just the running statistics are used. Whether the model is in training or inference mode should be toggled using `equinox.nn.inference_mode`.
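A minimal sketch of this pattern, assuming `equinox.nn.make_with_state` is used to create the initial state (shapes are illustrative):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(
    input_size=3, axis_name="batch"
)
x = jax.random.normal(jax.random.PRNGKey(0), (8, 3, 32, 32))  # (batch, channels, ...)

# The layer must be called inside a vmap whose axis_name matches the one above.
# The updated state is the same for every batch element, so it is left unbatched.
out, state = jax.vmap(
    bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(x, state)
```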
__init__(self, input_size: int, axis_name: Union[Hashable, Sequence[Hashable]], eps: float = 1e-05, channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, dtype = None)

Arguments:

- `input_size`: The number of channels in the input array.
- `axis_name`: The name of the batch axis to compute statistics over, as passed to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine parameters.
- `momentum`: The rate at which to update the running statistics. Should be a value between 0 and 1 exclusive.
- `inference`: If `False` then the batch means and variances will be calculated and used to update the running statistics. If `True` then the running statistics are directly used for normalisation. This may be toggled with `equinox.nn.inference_mode` or overridden during `equinox.nn.BatchNorm.__call__`.
- `dtype`: The dtype to use for the running statistics and the weight and bias if `channelwise_affine` is `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
__call__(self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None) -> tuple[Array, State]

Arguments:

- `x`: A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
- `state`: An `equinox.nn.State` object (which is used to store the running statistics).
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
- `inference`: As per `equinox.nn.BatchNorm.__init__`. If `True` or `False` then it will take priority over `self.inference`. If `None` then the value from `self.inference` will be used.
Returns:

A 2-tuple of:

- A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
- An updated state object (storing the updated running statistics).

Raises:

A `NameError` if no `vmap`s are placed around this operation, or if this `vmap` does not have a matching `axis_name`.
equinox.nn.SpectralNorm (StatefulLayer)
Applies spectral normalisation to a given parameter.
Given a weight matrix \(W\), and letting \(\sigma(W)\) denote (an approximation to) its largest singular value, this computes \(W/\sigma(W)\).

The approximation \(\sigma(W)\) is computed using power iterations, which are updated (as a side-effect) every time \(W/\sigma(W)\) is computed.
Spectral normalisation is particularly commonly used when training generative adversarial networks; see Spectral Normalization for Generative Adversarial Networks for more details and motivation.
Example
See this example for example usage.
Note that this layer behaves differently during training and inference. During training, power iterations are updated; during inference, they are fixed. Whether the model is in training or inference mode should be toggled using `equinox.nn.inference_mode`.
__init__(self, layer: ~_Layer, weight_name: str, num_power_iterations: int = 1, eps: float = 1e-12, inference: bool = False, *, key: PRNGKeyArray)

Arguments:

- `layer`: The layer to wrap. Usually an `equinox.nn.Linear` or a convolutional layer (e.g. `equinox.nn.Conv2d`).
- `weight_name`: The name of the layer's parameter (a JAX array) to apply spectral normalisation to.
- `num_power_iterations`: The number of power iterations to apply every time the array is accessed.
- `eps`: Epsilon for numerical stability when calculating norms.
- `inference`: Whether this is in inference mode, at which time no power iterations are performed. This may be toggled with `equinox.nn.inference_mode`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation. (Keyword only argument.)
Info
The `dtype` of the weight array of the `layer` input is applied to all parameters in this layer.
__call__(self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None) -> tuple[Array, State]

Arguments:

- `x`: A JAX array.
- `state`: An `equinox.nn.State` object (which is used to store the power iterations).
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
- `inference`: As per `equinox.nn.SpectralNorm.__init__`. If `True` or `False` then it will take priority over `self.inference`. If `None` then the value from `self.inference` will be used.
Returns:

A 2-tuple of:

- The JAX array from calling `self.layer(x)` (with spectral normalisation applied).
- An updated context object (storing the updated power iterations).
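A minimal sketch of wrapping a linear layer, assuming `equinox.nn.make_with_state` is used to create the power-iteration state (sizes are illustrative):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
linear = eqx.nn.Linear(4, 4, key=key1)
# Wrap the linear layer so that its `weight` parameter is spectrally normalised.
spectral, state = eqx.nn.make_with_state(eqx.nn.SpectralNorm)(
    layer=linear, weight_name="weight", key=key2
)
x = jnp.ones(4)
out, state = spectral(x, state)  # the power iteration stored in `state` is also updated
```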
equinox.nn.WeightNorm (Module)
Applies weight normalisation to a given parameter.
Given the 2D weight matrix \(W = (W_{ij})_{ij} \in \mathbb{R}^{\text{out} \times \text{in}}\) of a linear layer, it replaces it with the following reparameterisation:

\[g_i \frac{v_{ij}}{\lVert v_{i\, \cdot} \rVert} \in \mathbb{R}^{\text{out} \times \text{in}}\]

where \(v_{ij}\) is initialised as \(W_{ij}\), and each \(g_i\) is initialised as \(\lVert v_{i\, \cdot} \rVert = \sqrt{\sum_j {v_{ij}}^2}\).

Overall, the direction (\(v\)) and the magnitude (\(g\)) of the output of each neuron are treated separately.

Given an n-dimensional weight array \(W\) (as in convolutional layers), the normalisation is analogously computed over every axis except the first.
Cite
@article{DBLP:journals/corr/SalimansK16,
    author     = {Tim Salimans and Diederik P. Kingma},
    title      = {Weight Normalization: {A} Simple Reparameterization to
                  Accelerate Training of Deep Neural Networks},
    journal    = {CoRR},
    volume     = {abs/1602.07868},
    year       = {2016},
    url        = {http://arxiv.org/abs/1602.07868},
    eprinttype = {arXiv},
    eprint     = {1602.07868},
    timestamp  = {Mon, 13 Aug 2018 16:47:07 +0200},
    biburl     = {https://dblp.org/rec/journals/corr/SalimansK16.bib},
    bibsource  = {dblp computer science bibliography, https://dblp.org}
}
__init__(self, layer: ~_Layer, weight_name: str = 'weight', axis: Optional[int] = 0)

Arguments:

- `layer`: The layer to wrap. Usually an `equinox.nn.Linear` or a convolutional layer (e.g. `equinox.nn.Conv2d`).
- `weight_name`: The name of the layer's parameter (a JAX array) to apply weight normalisation to.
- `axis`: The norm is computed across every axis except this one. If `None`, compute across every axis.
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

- `x`: A JAX array.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.

Returns:

- The JAX array from calling `self.layer(x)` (with weight normalisation applied).
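A minimal sketch of wrapping a linear layer (sizes are illustrative):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
linear = eqx.nn.Linear(4, 4, key=key)
# Reparameterise the linear layer's `weight` into a direction `v` and a magnitude `g`.
wn = eqx.nn.WeightNorm(layer=linear, weight_name="weight")
x = jnp.ones(4)
out = wn(x)  # applies the reparameterised weight, then the layer's usual computation
```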