# Normalisation

#### `equinox.nn.LayerNorm` (Module)

Computes a mean and standard deviation over the whole input array, and uses these to normalise the whole array. Optionally applies an elementwise affine transformation afterwards.

Given an input array \(x\), this layer computes

\[\frac{x - \mathbb{E}[x]}{\sqrt{\text{Var}[x] + \varepsilon}} * \gamma + \beta\]

where \(\gamma\), \(\beta\) have the same shape as \(x\) if `elementwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `elementwise_affine=False`.
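
For concreteness, here is a small sketch checking this formula at default initialisation, where \(\gamma = 1\) and \(\beta = 0\) (the values of `x` are arbitrary):

```
import equinox as eqx
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
layer = eqx.nn.LayerNorm(shape=(3,))
# At initialisation the weight is all-ones and the bias all-zeros, so the
# output should match the formula with γ = 1, β = 0:
expected = (x - x.mean()) / jnp.sqrt(x.var() + layer.eps)
print(layer(x))  # ≈ expected
```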

## Cite

```
@article{ba2016layer,
author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
title={Layer Normalization},
year={2016},
journal={arXiv:1607.06450},
}
```

FAQ

If you need to normalise over only some input dimensions, then this can be
achieved by vmap'ing. For example, assuming `array` has shape `(batch, 20, 30)`,
the following will compute statistics over every dimension *except* the first:

```
import equinox as eqx, jax

layer = eqx.nn.LayerNorm(shape=(20, 30))  # shape of a single batch element
array = jax.vmap(layer)(array)
```

##### `__init__(self, shape: Union[int, Sequence[int]], eps: float = 1e-05, use_weight: bool = True, use_bias: bool = True, dtype = None, *, elementwise_affine: Optional[bool] = None)`

**Arguments:**

- `shape`: Shape of the input.
- `eps`: Value added to denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
- `use_bias`: Whether the module has learnable affine biases.
- `dtype`: The dtype to use for the weight and the bias in this layer if `use_weight` or `use_bias` is set to `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
- `elementwise_affine`: Deprecated alternative to `use_weight` and `use_bias`.

##### `__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]`

**Arguments:**

- `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
- `state`: Ignored; provided for interchangeability with the `equinox.nn.BatchNorm` API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

**Returns:**

The output is a JAX array of the same shape as `x`.

If `state` is passed, then a 2-tuple of `(output, state)` is returned. The state is passed through unchanged. If `state` is not passed, then just the output is returned.

#### `equinox.nn.RMSNorm` (Module)

A simplified version of LayerNorm which rescales the inputs, but does not center them. Optionally applies a learned reweighting of the transformed array afterward.

Given an input array \(x\), this layer computes

\[\frac{x}{\sqrt{\varepsilon + \frac{1}{n}\Vert x \Vert^2_2}} \gamma + \beta\]

where \(\Vert x \Vert^2_2 = \sum_{i=1}^n x_i^2\), \(n = \dim(x)\), and \(\gamma\) is a learned array with the same shape as \(x\) if `use_weight=True`, or \(\gamma = 1\) if `use_weight=False`, as proposed in this paper. \(\beta\) is an optional bias term.
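
As a sketch of this formula at default initialisation (where \(\gamma = 1\) and \(\beta = 0\)):

```
import equinox as eqx
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
layer = eqx.nn.RMSNorm(shape=(3,))
# jnp.mean(x**2) equals (1/n) * ||x||_2^2, so this matches the formula above:
expected = x / jnp.sqrt(layer.eps + jnp.mean(x**2))
print(layer(x))  # ≈ expected
```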

## Cite

```
@article{zhang2019root,
title={Root Mean Square Layer Normalization},
author={Biao Zhang and Rico Sennrich},
year={2019},
journal={arXiv:1910.07467}
}
```

##### `__init__(self, shape: Union[int, Sequence[int]], eps: float = 1e-05, use_weight: bool = True, use_bias: bool = True, dtype = None)`

**Arguments:**

- `shape`: Shape of the input.
- `eps`: Value added to denominator for numerical stability.
- `use_weight`: Whether the module has learnable affine weights.
- `use_bias`: Whether the module has a learnable affine shift.
- `dtype`: The dtype to use for the weight and the bias in this layer if `use_weight` or `use_bias` is set to `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.

##### `__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]`

**Arguments:**

- `x`: A JAX array, with the same shape as the `shape` passed to `__init__`.
- `state`: Ignored; provided for interchangeability with the `equinox.nn.BatchNorm` API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

**Returns:**

The output is a JAX array of the same shape as `x`.

If `state` is passed, then a 2-tuple of `(output, state)` is returned. The state is passed through unchanged. If `state` is not passed, then just the output is returned.

#### `equinox.nn.GroupNorm` (Module)

Splits the first dimension ("channels") into groups of fixed size. Computes a mean and standard deviation over the contents of each group, and uses these to normalise the group. Optionally applies a channel-wise affine transformation afterwards.

Given an input array \(x\) of shape `(channels, ...)`, this layer splits this up into `groups`-many arrays \(x_i\) each of shape `(channels/groups, ...)`, and for each one computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

where \(\gamma_i\), \(\beta_i\) have shape `(channels/groups,)` if `channelwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `channelwise_affine=False`.
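
A minimal usage sketch (the channel and spatial sizes here are arbitrary):

```
import equinox as eqx
import jax.numpy as jnp

# Six channels split into three groups of two; statistics are computed
# separately over the contents of each group.
layer = eqx.nn.GroupNorm(groups=3, channels=6)
x = jnp.ones((6, 32, 32))  # shape (channels, ...)
y = layer(x)  # same shape, (6, 32, 32)
```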

## Cite

```
@article{wu2018group,
author={Yuxin Wu and Kaiming He},
title={Group Normalization},
year={2018},
journal={arXiv:1803.08494},
}
```

##### `__init__(self, groups: int, channels: Optional[int] = None, eps: float = 1e-05, channelwise_affine: bool = True, dtype = None)`

**Arguments:**

- `groups`: The number of groups to split the input into.
- `channels`: The number of input channels. May be left unspecified (e.g. just `None`) if `channelwise_affine=False`.
- `eps`: Value added to denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable affine parameters.
- `dtype`: The dtype to use for the weight and the bias in this layer if `channelwise_affine` is set to `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.

##### `__call__(self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None) -> Union[Array, tuple[Array, State]]`

**Arguments:**

- `x`: A JAX array of shape `(channels, ...)`.
- `state`: Ignored; provided for interchangeability with the `equinox.nn.BatchNorm` API.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

**Returns:**

The output is a JAX array of shape `(channels, ...)`.

If `state` is passed, then a 2-tuple of `(output, state)` is returned. The state is passed through unchanged. If `state` is not passed, then just the output is returned.

#### `equinox.nn.BatchNorm` (StatefulLayer)

Computes a mean and standard deviation over the batch and spatial dimensions of an array, and uses these to normalise the whole array. Optionally applies a channelwise affine transformation afterwards.

Given an input array \(x = [x_1, ... x_C]\) with \(C\) channels, this layer computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

for all \(i\). Here \(*\) denotes elementwise multiplication and \(\gamma\), \(\beta\) have shape \((C,)\) if `channelwise_affine=True` and \(\gamma = 1\), \(\beta = 0\) if `channelwise_affine=False`. Expectations are computed over all spatial dimensions *and* over the batch dimension, and updated batch-by-batch according to `momentum`.

Example

See this example for example usage.

Warning

This layer must be used inside of a `vmap` or `pmap` with a matching `axis_name`. (Not doing so will raise a `NameError`.)

Note that this layer behaves differently during training and inference. During training, statistics are computed using the input data and the running statistics are updated. During inference, only the running statistics are used. Whether the model is in training or inference mode should be toggled using `equinox.nn.inference_mode`.
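
Putting the warning and the state handling together, a minimal sketch of using this layer under `jax.vmap` might look as follows (the shapes are arbitrary, and `equinox.nn.make_with_state` is one way to construct the initial state):

```
import equinox as eqx
import jax
import jax.numpy as jnp

bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(
    input_size=3, axis_name="batch"
)
x = jnp.ones((8, 3, 32, 32))  # (batch, channels, spatial...)
# Batch over the array but not the state; the matching axis_name lets
# BatchNorm compute statistics across the batch dimension.
y, state = jax.vmap(
    bn, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)
```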

##### `__init__(self, input_size: int, axis_name: Union[Hashable, Sequence[Hashable]], eps: float = 1e-05, channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, dtype = None)`

**Arguments:**

- `input_size`: The number of channels in the input array.
- `axis_name`: The name of the batch axis to compute statistics over, as passed to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine parameters.
- `momentum`: The rate at which to update the running statistics. Should be a value between 0 and 1 exclusive.
- `inference`: If `False` then the batch means and variances will be calculated and used to update the running statistics. If `True` then the running statistics are directly used for normalisation. This may be toggled with `equinox.nn.inference_mode` or overridden during `equinox.nn.BatchNorm.__call__`.
- `dtype`: The dtype to use for the running statistics and the weight and bias if `channelwise_affine` is `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.

##### `__call__(self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None) -> tuple[Array, State]`

**Arguments:**

- `x`: A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
- `state`: An `equinox.nn.State` object (which is used to store the running statistics).
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
- `inference`: As per `equinox.nn.BatchNorm.__init__`. If `True` or `False` then it will take priority over `self.inference`. If `None` then the value from `self.inference` will be used.

**Returns:**

A 2-tuple of:

- A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
- An updated state object (storing the updated running statistics).

**Raises:**

A `NameError` if no `vmap`s are placed around this operation, or if this vmap does not have a matching `axis_name`.

#### `equinox.nn.SpectralNorm` (StatefulLayer)

Applies spectral normalisation to a given parameter.

Given a weight matrix \(W\), and letting \(\sigma(W)\) denote (an approximation to) its largest singular value, this layer computes \(W / \sigma(W)\).

The approximation \(σ(W)\) is computed using power iterations which are updated (as a side-effect) every time \(W/σ(W)\) is computed.

Spectral normalisation is particularly commonly used when training generative adversarial networks; see Spectral Normalization for Generative Adversarial Networks for more details and motivation.

Example

See this example for example usage.

Note that this layer behaves differently during training and inference. During training, the power iterations are updated; during inference they are fixed. Whether the model is in training or inference mode should be toggled using `equinox.nn.inference_mode`.
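
A minimal usage sketch, wrapping an `equinox.nn.Linear` (the sizes and keys are arbitrary):

```
import equinox as eqx
import jax.random as jr

wkey, skey, xkey = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 10, key=wkey)
spectral, state = eqx.nn.make_with_state(eqx.nn.SpectralNorm)(
    layer=linear, weight_name="weight", key=skey
)
x = jr.normal(xkey, (10,))
# Applies the linear layer with its weight divided by (an approximation to)
# its largest singular value; the power iterations are carried in `state`.
y, state = spectral(x, state)
```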

##### `__init__(self, layer: ~_Layer, weight_name: str, num_power_iterations: int = 1, eps: float = 1e-12, inference: bool = False, *, key: PRNGKeyArray)`

**Arguments:**

- `layer`: The layer to wrap. Usually an `equinox.nn.Linear` or a convolutional layer (e.g. `equinox.nn.Conv2d`).
- `weight_name`: The name of the layer's parameter (a JAX array) to apply spectral normalisation to.
- `num_power_iterations`: The number of power iterations to apply every time the array is accessed.
- `eps`: Epsilon for numerical stability when calculating norms.
- `inference`: Whether this is in inference mode, at which time no power iterations are performed. This may be toggled with `equinox.nn.inference_mode`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation. (Keyword only argument.)

Info

The `dtype` of the weight array of the `layer` input is applied to all parameters in this layer.

##### `__call__(self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None) -> tuple[Array, State]`

**Arguments:**

- `x`: A JAX array.
- `state`: An `equinox.nn.State` object (which is used to store the power iterations).
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
- `inference`: As per `equinox.nn.SpectralNorm.__init__`. If `True` or `False` then it will take priority over `self.inference`. If `None` then the value from `self.inference` will be used.

**Returns:**

A 2-tuple of:

- The JAX array from calling `self.layer(x)` (with spectral normalisation applied).
- An updated state object (storing the updated power iterations).

#### `equinox.nn.WeightNorm` (Module)

Applies weight normalisation to a given parameter.

Given the 2D weight matrix \(W = (W_{ij})_{ij} \in \mathbb{R}^{\text{out} \times \text{in}}\) of a linear layer, this layer replaces it with the following reparameterisation:

\[W = \left( g_i \frac{v_{ij}}{\lVert v_{i\, \cdot} \rVert} \right)_{ij} \in \mathbb{R}^{\text{out} \times \text{in}}\]

where \(v_{ij}\) is initialised as \(W_{ij}\), and \(g_i\) is initialised as \(\lVert v_{i\, \cdot} \rVert = \sqrt{\sum_j {v_{ij}}^2}\).

Overall, the direction (\(v\)) and the magnitude (\(g\)) of the output of each neuron are treated separately.

Given an n-dimensional weight array \(W\) (as in convolutional layers), the normalisation is analogously computed over every axis except the first.
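
A minimal usage sketch (sizes and keys arbitrary). At initialisation \(v = W\) and \(g_i = \lVert v_{i\, \cdot} \rVert\), so the wrapped layer initially computes exactly the same function as the unwrapped one:

```
import equinox as eqx
import jax.random as jr

wkey, xkey = jr.split(jr.PRNGKey(0))
linear = eqx.nn.Linear(4, 8, key=wkey)
wn = eqx.nn.WeightNorm(linear)  # defaults: weight_name="weight", axis=0
x = jr.normal(xkey, (4,))
y = wn(x)  # equal to linear(x) at initialisation
```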

## Cite

```
@article{DBLP:journals/corr/SalimansK16,
author = {Tim Salimans and
Diederik P. Kingma},
title = {Weight Normalization: {A} Simple
Reparameterization to Accelerate
Training of Deep Neural Networks},
journal = {CoRR},
volume = {abs/1602.07868},
year = {2016},
url = {http://arxiv.org/abs/1602.07868},
eprinttype = {arXiv},
eprint = {1602.07868},
timestamp = {Mon, 13 Aug 2018 16:47:07 +0200},
biburl = {https://dblp.org/rec/journals/corr/SalimansK16.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

##### `__init__(self, layer: ~_Layer, weight_name: str = 'weight', axis: Optional[int] = 0)`

**Arguments:**

- `layer`: The layer to wrap. Usually an `equinox.nn.Linear` or a convolutional layer (e.g. `equinox.nn.Conv2d`).
- `weight_name`: The name of the layer's parameter (a JAX array) to apply weight normalisation to.
- `axis`: The norm is computed across every axis except this one. If `None`, compute across every axis.

##### `__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array`

**Arguments:**

- `x`: A JAX array.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.

**Returns:**

The JAX array from calling `self.layer(x)` (with weight normalisation applied).