# Normalisation

#### `equinox.nn.LayerNorm (Module)`

Computes a mean and standard deviation over the whole input array, and uses these to normalise the whole array. Optionally applies an elementwise affine transformation afterwards.

Given an input array \(x\), this layer computes

$$\frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta$$

where \(\gamma\), \(\beta\) have the same shape as \(x\) if `elementwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `elementwise_affine=False`.
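For intuition, the computation can be sketched in pure Python over a flat list. (This mirrors the formula only; it is not the Equinox implementation, which operates on JAX arrays of arbitrary shape.)

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    # Statistics are computed over the *whole* input, not per-feature.
    mean = sum(v for v in x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    out = [(v - mean) / math.sqrt(var + eps) for v in x]
    if gamma is not None and beta is not None:
        # Elementwise affine transform, as with elementwise_affine=True.
        out = [g * o + b for g, o, b in zip(gamma, out, beta)]
    return out

print(layer_norm([1.0, 2.0, 3.0, 4.0]))
# The output has approximately zero mean and unit variance.
```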

## Cite

```
@article{ba2016layer,
  author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
  title={Layer Normalization},
  year={2016},
  journal={arXiv:1607.06450},
}
```

FAQ

If you need to normalise over only some input dimensions, then this can be
achieved by vmap'ing. For example the following will compute statistics over
every dimension *except* the first:

```
layer = LayerNorm(...)
array = jax.vmap(layer)(array)
```

##### `__init__(self, shape: Union[NoneType, int, Sequence[int]], eps: float = 1e-05, elementwise_affine: bool = True, **kwargs)`

**Arguments:**

- `shape`: Input shape. May be left unspecified (e.g. just `None`) if `elementwise_affine=False`.
- `eps`: Value added to the denominator for numerical stability.
- `elementwise_affine`: Whether the module has learnable affine parameters.

##### `__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array`

**Arguments:**

- `x`: A JAX array whose shape is given by `shape`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

**Returns:**

A JAX array of shape `shape`.

#### `equinox.nn.GroupNorm (Module)`

Splits the first dimension ("channels") into groups of fixed size. Computes a mean and standard deviation over the contents of each group, and uses these to normalise the group. Optionally applies a channel-wise affine transformation afterwards.

Given an input array \(x\) of shape `(channels, ...)`, this layer splits it up into `groups`-many arrays \(x_i\) each of shape `(channels/groups, ...)`, and for each one computes

$$\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\mathrm{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i$$

where \(\gamma_i\), \(\beta_i\) have shape `(channels/groups,)` if `channelwise_affine=True`, and \(\gamma_i = 1\), \(\beta_i = 0\) if `channelwise_affine=False`.
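The group splitting can likewise be sketched in pure Python for a flat channel vector. (Illustrative only: the real layer also handles trailing spatial dimensions and learnable affine parameters.)

```python
import math

def group_norm(x, groups, eps=1e-5):
    # Split the channels into `groups` contiguous groups and normalise
    # each group by its own mean and variance.
    size = len(x) // groups
    out = []
    for g in range(groups):
        xg = x[g * size:(g + 1) * size]
        mean = sum(xg) / size
        var = sum((v - mean) ** 2 for v in xg) / size
        out += [(v - mean) / math.sqrt(var + eps) for v in xg]
    return out

print(group_norm([1.0, 2.0, 10.0, 20.0], groups=2))
# Each group is normalised independently of the others.
```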

## Cite

```
@article{wu2018group,
  author={Yuxin Wu and Kaiming He},
  title={Group Normalization},
  year={2018},
  journal={arXiv:1803.08494},
}
```

##### `__init__(self, groups: int, channels: Optional[int] = None, eps: float = 1e-05, channelwise_affine: bool = True, **kwargs)`

**Arguments:**

- `groups`: The number of groups to split the input into.
- `channels`: The number of input channels. May be left unspecified (e.g. just `None`) if `channelwise_affine=False`.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable affine parameters.

##### `__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array`

**Arguments:**

- `x`: A JAX array of shape `(channels, ...)`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

**Returns:**

A JAX array of shape `(channels, ...)`.

#### `equinox.experimental.BatchNorm (Module)`

Computes a mean and standard deviation over the batch and spatial dimensions of an array, and uses these to normalise the whole array. Optionally applies a channelwise affine transformation afterwards.

Given an input array \(x = [x_1, ... x_C]\) with \(C\) channels, this layer computes

$$\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\mathrm{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i$$

for all \(i\). Here \(*\) denotes elementwise multiplication and \(\gamma\), \(\beta\) have shape \((C,)\) if `channelwise_affine=True` and \(\gamma = 1\), \(\beta = 0\) if `channelwise_affine=False`. Expectations are computed over all spatial dimensions *and* over the batch dimension, and updated batch-by-batch according to `momentum`.

Example

```
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
# BatchNorm will automatically update its running statistics internally.
```

Warning

This layer must be used inside of a `vmap` or `pmap` with a matching `axis_name`. (Not doing so will raise a `NameError`.)

Warning

`equinox.experimental.BatchNorm` updates its running statistics as a side effect of its forward pass. Side effects are quite unusual in JAX; as such `BatchNorm` is considered experimental. Let us know how you find it!

##### `__init__(self, input_size: int, axis_name: Union[Hashable, Sequence[Hashable]], eps: float = 1e-05, channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, **kwargs)`

**Arguments:**

- `input_size`: The number of channels in the input array.
- `axis_name`: The name of the batch axis to compute statistics over, as passed to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine parameters.
- `momentum`: The rate at which to update the running statistics. Should be a value between 0 and 1 exclusive.
- `inference`: If `False` then the batch means and variances will be calculated and used to update the running statistics. If `True` then the running statistics are directly used for normalisation. This may be toggled with `equinox.tree_inference` or overridden during `equinox.experimental.BatchNorm.__call__`.

##### `__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None, inference: Optional[bool] = None) -> Array`

**Arguments:**

- `x`: A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
- `inference`: As per `equinox.experimental.BatchNorm.__init__`. If `True` or `False` then it will take priority over `self.inference`. If `None` then the value from `self.inference` will be used.

**Returns:**

A JAX array of shape `(input_size, dim_1, ..., dim_N)`.

**Raises:**

A `NameError` if no `vmap`s are placed around this operation, or if this `vmap` does not have a matching `axis_name`.

#### `equinox.experimental.SpectralNorm (Module)`

Applies spectral normalisation to a given parameter.

Given a weight matrix \(W\), and letting \(σ(W)\) denote (an approximation to) its largest singular value, then this computes \(W/σ(W)\).

The approximation \(σ(W)\) is computed using power iterations which are updated (as a side-effect) every time \(W/σ(W)\) is computed.
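The power iteration itself can be sketched in pure Python for a small matrix stored as a list of rows. (An illustration of the approximation only, not Equinox's implementation, which also carries the \(u\), \(v\) vectors between calls.)

```python
import math

def power_iteration_sigma(W, num_iters=20, eps=1e-12):
    # Approximate the largest singular value sigma(W) by alternately
    # updating the left/right singular vector estimates u and v.
    m, n = len(W), len(W[0])
    u = [1.0] * m
    for _ in range(num_iters):
        # v = normalise(W^T u)
        v = [sum(W[i][j] * u[i] for i in range(m)) for j in range(n)]
        nv = math.sqrt(sum(x * x for x in v)) + eps
        v = [x / nv for x in v]
        # u = normalise(W v)
        u = [sum(W[i][j] * v[j] for j in range(n)) for i in range(m)]
        nu = math.sqrt(sum(x * x for x in u)) + eps
        u = [x / nu for x in u]
    # sigma ≈ u^T W v
    return sum(u[i] * W[i][j] * v[j] for i in range(m) for j in range(n))

W = [[2.0, 0.0], [0.0, 1.0]]
print(power_iteration_sigma(W))  # approximately 2.0, the largest singular value
```

Spectral normalisation would then divide each entry of `W` by this estimate.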

Spectral normalisation is particularly commonly used when training generative adversarial networks; see [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) for more details and motivation.

`equinox.experimental.SpectralNorm` should be used to replace an individual parameter, not (as in some libraries) the layer containing that parameter.

Example

To add spectral normalisation during model creation:

```
import equinox as eqx
import equinox.experimental as eqxe
import jax.random as jr
key = jr.PRNGKey(0)
linear = eqx.nn.Linear(2, 2, key=key)
sn_weight = eqxe.SpectralNorm(linear.weight, key=key)
linear = eqx.tree_at(lambda l: l.weight, linear, sn_weight)
```

Example

Alternatively, iterate over the model to add spectral normalisation after model creation:

```
import equinox as eqx
import equinox.experimental as eqxe
import jax
import jax.random as jr
import jax.tree_util as jtu
import functools as ft
key = jr.PRNGKey(0)
model_key, spectral_key = jr.split(key)
SN = ft.partial(eqxe.SpectralNorm, key=spectral_key)
def _is_linear(leaf):
    return isinstance(leaf, eqx.nn.Linear)
def _apply_sn_to_linear(module):
    if _is_linear(module):
        module = eqx.tree_at(lambda m: m.weight, module, replace_fn=SN)
    return module
def apply_sn(model):
    return jtu.tree_map(_apply_sn_to_linear, model, is_leaf=_is_linear)
model = eqx.nn.MLP(2, 2, 2, 2, key=model_key)
model_with_sn = apply_sn(model)
```

Example

Switching the model to inference mode after training:

```
import equinox as eqx
import equinox.experimental as eqxe
import jax.tree_util as jtu
def _is_sn(leaf):
    return isinstance(leaf, eqxe.SpectralNorm)
def _set_inference_on_sn(module):
    if _is_sn(module):
        module = eqx.tree_at(lambda m: m.inference, module, True)
    return module
def set_inference(model):
    return jtu.tree_map(_set_inference_on_sn, model, is_leaf=_is_sn)
model = ... # set up model, train it, etc.
model = set_inference(model)
```

Warning

`equinox.experimental.SpectralNorm` updates its running statistics as a side effect of its forward pass. Side effects are quite unusual in JAX; as such `SpectralNorm` is considered experimental. Let us know how you find it!

##### `__init__(self, weight: Array, num_power_iterations: int = 1, eps: float = 1e-12, inference: bool = False, *, key: jax.random.PRNGKey, **kwargs)`

**Arguments:**

- `weight`: The parameter (a JAX array) to apply spectral normalisation to.
- `num_power_iterations`: The number of power iterations to apply every time the array is accessed.
- `eps`: Epsilon for numerical stability when calculating norms.
- `inference`: Whether this is in inference mode, at which time no power iterations are performed. This may be toggled with `equinox.tree_inference`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation. (Keyword only argument.)