Normalisation¤

equinox.nn.LayerNorm (Module) ¤

Computes a mean and standard deviation over the whole input array, and uses these to normalise the whole array. Optionally applies an elementwise affine transformation afterwards.

Given an input array \(x\), this layer computes

\[\frac{x - \mathbb{E}[x]}{\sqrt{\text{Var}[x] + \varepsilon}} * \gamma + \beta\]

where \(\gamma\), \(\beta\) have the same shape as \(x\) if elementwise_affine=True, and \(\gamma = 1\), \(\beta = 0\) if elementwise_affine=False.
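The formula above can be sketched in plain jax.numpy. This is a minimal reference under stated assumptions, not the library's implementation: the helper name layer_norm_reference is hypothetical, and the variance is assumed to be the biased estimate that jnp.var returns by default.

```python
import jax.numpy as jnp


def layer_norm_reference(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalise the whole array, then apply an elementwise affine map."""
    mean = jnp.mean(x)  # E[x], taken over every element of x
    var = jnp.var(x)    # Var[x], taken over every element (biased estimate)
    return (x - mean) / jnp.sqrt(var + eps) * gamma + beta


x = jnp.arange(6.0).reshape(2, 3)
y = layer_norm_reference(x)  # gamma = 1, beta = 0: no affine transformation
```

With the default gamma and beta, the output has approximately zero mean and unit variance over the whole array.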

Cite

Layer Normalization

@article{ba2016layer,
    author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
    title={Layer Normalization},
    year={2016},
    journal={arXiv:1607.06450},
}

FAQ

If you need to normalise over only some input dimensions, then this can be achieved by vmap'ing. For example the following will compute statistics over every dimension except the first:

layer = LayerNorm(...)
array = jax.vmap(layer)(array)

__init__(self, shape: Union[int, Sequence[int]], eps: float = 1e-05, use_weight: bool = True, use_bias: bool = True, dtype = None, *, elementwise_affine: Optional[bool] = None) ¤

Arguments:

  • shape: Shape of the input.
  • eps: Value added to denominator for numerical stability.
  • use_weight: Whether the module has learnable affine weights.
  • use_bias: Whether the module has learnable affine biases.
  • dtype: The dtype to use for the weight and the bias in this layer if use_weight or use_bias is set to True. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • elementwise_affine: Deprecated alternative to use_weight and use_bias.
__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]] ¤

Arguments:

  • x: A JAX array, with the same shape as the shape passed to __init__.
  • state: Ignored; provided for interchangeability with the equinox.nn.BatchNorm API.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

The output is a JAX array of the same shape as x.

If state is passed, then a 2-tuple of (output, state) is returned. The state is passed through unchanged. If state is not passed, then just the output is returned.


equinox.nn.RMSNorm (Module) ¤

A simplified version of LayerNorm which rescales the inputs, but does not center them. Optionally applies a learned reweighting of the transformed array afterward.

Given an input array \(x\), this layer computes

\[\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta\]

where \(\Vert x \Vert^2_2 = \sum_{i=1}^n x_i^2\), \(n = \dim(x)\), and \(\gamma\) is a learned array with the same shape as \(x\) if use_weight=True, or \(\gamma = 1\) if use_weight=False, as proposed in the paper cited below. \(\beta\) is an optional bias term: a learned array with the same shape as \(x\) if use_bias=True, or \(\beta = 0\) if use_bias=False.
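As a rough illustration of the formula above, the computation can be written directly in jax.numpy. The helper name rms_norm_reference is hypothetical, and this sketch is not taken from the library's source.

```python
import jax.numpy as jnp


def rms_norm_reference(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Rescale x by its root mean square; the input is not centred."""
    mean_square = jnp.mean(x**2)  # (1/n) * ||x||_2^2
    return x / jnp.sqrt(eps + mean_square) * gamma + beta


x = jnp.array([3.0, -4.0])
y = rms_norm_reference(x)  # gamma = 1, beta = 0
```

After rescaling, the mean of the squared outputs is approximately 1, while the mean of the outputs is generally not 0.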

Cite

Root Mean Square Layer Normalization

@article{zhang2019root,
    title={Root Mean Square Layer Normalization},
    author={Biao Zhang and Rico Sennrich},
    year={2019},
    journal={arXiv:1910.07467}
}
__init__(self, shape: Union[int, Sequence[int]], eps: float = 1e-05, use_weight: bool = True, use_bias: bool = True, dtype = None) ¤

Arguments:

  • shape: Shape of the input.
  • eps: Value added to denominator for numerical stability.
  • use_weight: Whether the module has learnable affine weights.
  • use_bias: Whether the module has learnable affine shift.
  • dtype: The dtype to use for the weight and the bias in this layer if use_weight or use_bias is set to True. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]] ¤

Arguments:

  • x: A JAX array, with the same shape as the shape passed to __init__.
  • state: Ignored; provided for interchangeability with the equinox.nn.BatchNorm API.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

The output is a JAX array of the same shape as x.

If state is passed, then a 2-tuple of (output, state) is returned. The state is passed through unchanged. If state is not passed, then just the output is returned.


equinox.nn.GroupNorm (Module) ¤

Splits the first dimension ("channels") into groups of fixed size. Computes a mean and standard deviation over the contents of each group, and uses these to normalise the group. Optionally applies a channel-wise affine transformation afterwards.

Given an input array \(x\) of shape (channels, ...), this layer splits this up into groups-many arrays \(x_i\) each of shape (channels/groups, ...), and for each one computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

where \(\gamma_i\), \(\beta_i\) have shape (channels/groups,) if channelwise_affine=True, and \(\gamma = 1\), \(\beta = 0\) if channelwise_affine=False.
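The grouping step can be sketched with a reshape in jax.numpy. This is a minimal illustration that omits the affine transformation; the helper name group_norm_reference is hypothetical, and it assumes that groups are formed from contiguous channels.

```python
import jax.numpy as jnp


def group_norm_reference(x, groups, eps=1e-5):
    """Normalise each group of channels separately (no affine step)."""
    channels = x.shape[0]
    if channels % groups != 0:
        raise ValueError("channels must be divisible by groups")
    # Each row holds channels/groups contiguous channels, with all of
    # their trailing dimensions flattened together.
    grouped = x.reshape(groups, -1)
    mean = jnp.mean(grouped, axis=1, keepdims=True)
    var = jnp.var(grouped, axis=1, keepdims=True)
    normed = (grouped - mean) / jnp.sqrt(var + eps)
    return normed.reshape(x.shape)


x = jnp.arange(24.0).reshape(4, 3, 2)  # (channels, ...)
y = group_norm_reference(x, groups=2)  # channels 0-1 and 2-3 form the two groups
```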

Cite

Group Normalization

@article{wu2018group,
    author={Yuxin Wu and Kaiming He},
    title={Group Normalization},
    year={2018},
    journal={arXiv:1803.08494},
}
__init__(self, groups: int, channels: Optional[int] = None, eps: float = 1e-05, channelwise_affine: bool = True, dtype = None) ¤

Arguments:

  • groups: The number of groups to split the input into.
  • channels: The number of input channels. May be left unspecified (e.g. just None) if channelwise_affine=False.
  • eps: Value added to denominator for numerical stability.
  • channelwise_affine: Whether the module has learnable affine parameters.
  • dtype: The dtype to use for the weight and the bias in this layer if channelwise_affine is set to True. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]] ¤

Arguments:

  • x: A JAX array of shape (channels, ...).
  • state: Ignored; provided for interchangeability with the equinox.nn.BatchNorm API.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

The output is a JAX array of shape (channels, ...).

If state is passed, then a 2-tuple of (output, state) is returned. The state is passed through unchanged. If state is not passed, then just the output is returned.


equinox.nn.BatchNorm (StatefulLayer) ¤

Computes a mean and standard deviation over the batch and spatial dimensions of an array, and uses these to normalise the whole array. Optionally applies a channelwise affine transformation afterwards.

Given an input array \(x = [x_1, ... x_C]\) with \(C\) channels, this layer computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

for all \(i\). Here \(*\) denotes elementwise multiplication and \(\gamma\), \(\beta\) have shape \((C,)\) if channelwise_affine=True and \(\gamma = 1\), \(\beta = 0\) if channelwise_affine=False. Expectations are computed over all spatial dimensions and over the batch dimension, and updated batch-by-batch according to momentum.

Example

See this example for example usage.

Warning

This layer must be used inside of a vmap or pmap with a matching axis_name. (Not doing so will raise a NameError.)

Note that this layer behaves differently during training and inference. During training, statistics are computed from the input data and the running statistics are updated. During inference, only the running statistics are used. Whether the model is in training or inference mode should be toggled using equinox.nn.inference_mode.

__init__(self, input_size: int, axis_name: Union[Hashable, Sequence[Hashable]], eps: float = 1e-05, channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, dtype = None) ¤

Arguments:

  • input_size: The number of channels in the input array.
  • axis_name: The name of the batch axis to compute statistics over, as passed to axis_name in jax.vmap or jax.pmap. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes.
  • eps: Value added to the denominator for numerical stability.
  • channelwise_affine: Whether the module has learnable channel-wise affine parameters.
  • momentum: The rate at which to update the running statistics. Should be a value between 0 and 1 exclusive.
  • inference: If False then the batch means and variances will be calculated and used to update the running statistics. If True then the running statistics are directly used for normalisation. This may be toggled with equinox.nn.inference_mode or overridden during equinox.nn.BatchNorm.__call__.
  • dtype: The dtype to use for the running statistics and the weight and bias if channelwise_affine is True. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
__call__(self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None) -> tuple[Array, State] ¤

Arguments:

  • x: A JAX array of shape (input_size, dim_1, ..., dim_N).
  • state: An equinox.nn.State object (which is used to store the running statistics).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
  • inference: As per equinox.nn.BatchNorm.__init__. If True or False then it will take priority over self.inference. If None then the value from self.inference will be used.

Returns:

A 2-tuple of:

  • A JAX array of shape (input_size, dim_1, ..., dim_N).
  • An updated state object (storing the updated running statistics).

Raises:

A NameError if no vmaps are placed around this operation, or if this vmap does not have a matching axis_name.


equinox.nn.SpectralNorm (StatefulLayer) ¤

Applies spectral normalisation to a given parameter.

Given a weight matrix \(W\), and letting \(\sigma(W)\) denote (an approximation to) its largest singular value, this computes \(W/\sigma(W)\).

The approximation \(\sigma(W)\) is computed using power iterations, which are updated (as a side effect) every time \(W/\sigma(W)\) is computed.
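The power iteration can be illustrated on a small matrix whose largest singular value is known. This is a sketch of the general technique, not the library's code: the helper name power_iteration_step and the starting vector are assumptions made for the example.

```python
import jax.numpy as jnp


def power_iteration_step(weight, v, eps=1e-12):
    """One power-iteration step towards the top singular vectors of weight."""
    u = weight @ v
    u = u / jnp.maximum(jnp.linalg.norm(u), eps)
    v = weight.T @ u
    v = v / jnp.maximum(jnp.linalg.norm(v), eps)
    return u, v


weight = jnp.array([[3.0, 0.0], [0.0, 1.0]])  # singular values are 3 and 1
v = jnp.ones(2) / jnp.sqrt(2.0)               # arbitrary unit starting vector

for _ in range(50):
    u, v = power_iteration_step(weight, v)

sigma = u @ weight @ v             # estimate of the largest singular value
normalised_weight = weight / sigma
```

In the layer itself, only num_power_iterations steps are applied per call and the vectors are carried in the state, so the estimate improves gradually over many calls.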

Spectral normalisation is particularly commonly used when training generative adversarial networks; see Spectral Normalization for Generative Adversarial Networks for more details and motivation.

Example

See this example for example usage.

Note that this layer behaves differently during training and inference. During training, the power iterations are updated; during inference, they are fixed. Whether the model is in training or inference mode should be toggled using equinox.nn.inference_mode.

__init__(self, layer: ~_Layer, weight_name: str, num_power_iterations: int = 1, eps: float = 1e-12, inference: bool = False, *, key: PRNGKeyArray) ¤

Arguments:

  • layer: The layer to wrap. Usually an equinox.nn.Linear or a convolutional layer (e.g. equinox.nn.Conv2d).
  • weight_name: The name of the layer's parameter (a JAX array) to apply spectral normalisation to.
  • num_power_iterations: The number of power iterations to apply every time the array is accessed.
  • eps: Epsilon for numerical stability when calculating norms.
  • inference: Whether this is in inference mode, at which time no power iterations are performed. This may be toggled with equinox.nn.inference_mode.
  • key: A jax.random.PRNGKey used to provide randomness for initialisation. (Keyword only argument.)

Info

The dtype of the wrapped layer's weight array is used for all parameters in this layer.

__call__(self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None) -> tuple[Array, State] ¤

Arguments:

  • x: A JAX array.
  • state: An equinox.nn.State object (which is used to store the power iterations).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
  • inference: As per equinox.nn.SpectralNorm.__init__. If True or False then it will take priority over self.inference. If None then the value from self.inference will be used.

Returns:

A 2-tuple of:

  • The JAX array from calling self.layer(x) (with spectral normalisation applied).
  • An updated state object (storing the updated power iterations).

equinox.nn.WeightNorm (Module) ¤

Applies weight normalisation to a given parameter.

Given the 2D weight matrix \(W = (W_{ij})_{ij} \in \mathbb{R}^{\text{out} \times \text{in}}\) of a linear layer, it replaces it with the following reparameterisation:

\(\left(g_i \frac{v_{ij}}{\lVert v_{i\, \cdot} \rVert}\right)_{ij} \in \mathbb{R}^{\text{out} \times \text{in}}\)

where \(v_{ij}\) is initialised as \(W_{ij}\), and \(g_i\) is initialised as \(\lVert v_{i\, \cdot} \rVert = \sqrt{\sum_j {v_{ij}}^2}\).

Overall, the direction (\(v\)) and the magnitude (\(g\)) of the output of each neuron are treated separately.

For n-dimensional weight arrays \(W\) (as in convolutional layers), the normalisation is analogously computed over every axis except the first.
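The 2D reparameterisation can be checked numerically with a small sketch in jax.numpy. The helper name reparameterise is hypothetical; the row norm is taken to be the Euclidean norm, so that the reparameterised matrix equals \(W\) at initialisation.

```python
import jax.numpy as jnp

W = jnp.array([[3.0, 4.0], [0.0, 2.0]])  # shape (out, in)

# Initialisation: v starts as W, g as the Euclidean norm of each row of v.
v = W
g = jnp.linalg.norm(v, axis=1, keepdims=True)  # shape (out, 1)


def reparameterise(g, v):
    """Rebuild the weight from a per-row magnitude g and a direction v."""
    return g * v / jnp.linalg.norm(v, axis=1, keepdims=True)


W_reparam = reparameterise(g, v)  # equals W at initialisation
```

During training, g and v would be updated separately, so each row's magnitude and direction can change independently.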

Cite

Weight Normalisation

@article{DBLP:journals/corr/SalimansK16,
author       = {Tim Salimans and
                Diederik P. Kingma},
title        = {Weight Normalization: {A} Simple
                Reparameterization to Accelerate
                Training of Deep Neural Networks},
journal      = {CoRR},
volume       = {abs/1602.07868},
year         = {2016},
url          = {http://arxiv.org/abs/1602.07868},
eprinttype   = {arXiv},
eprint       = {1602.07868},
timestamp    = {Mon, 13 Aug 2018 16:47:07 +0200},
biburl       = {https://dblp.org/rec/journals/corr/SalimansK16.bib},
bibsource    = {dblp computer science bibliography, https://dblp.org}
}
__init__(self, layer: ~_Layer, weight_name: str = 'weight', axis: Optional[int] = 0) ¤

Arguments:

  • layer: The layer to wrap. Usually an equinox.nn.Linear or a convolutional layer (e.g. equinox.nn.Conv2d).
  • weight_name: The name of the layer's parameter (a JAX array) to apply weight normalisation to.
  • axis: The norm is computed across every axis except this one. If None, compute across every axis.
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array ¤

Arguments:

  • x: A JAX Array.
  • key: Ignored; provided for compatibility with the rest of the Equinox API.

Returns:

  • The JAX array from calling self.layer(x) (with weight normalisation applied).