Normalisation

equinox.nn.LayerNorm (Module)

Computes a mean and standard deviation over the whole input array, and uses these to normalise the whole array. Optionally applies an elementwise affine transformation afterwards.

Given an input array \(x\), this layer computes

\[\frac{x - \mathbb{E}[x]}{\sqrt{\text{Var}[x] + \varepsilon}} * \gamma + \beta\]

where \(\gamma\), \(\beta\) have the same shape as \(x\) if elementwise_affine=True, and \(\gamma = 1\), \(\beta = 0\) if elementwise_affine=False.

Cite

Layer Normalization

@article{ba2016layer,
    author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
    title={Layer Normalization},
    year={2016},
    journal={arXiv:1607.06450},
}

FAQ

If you need to normalise over only some input dimensions, then this can be achieved by vmap'ing. For example, the following will compute statistics over every dimension except the first:

layer = eqx.nn.LayerNorm(...)
array = jax.vmap(layer)(array)

__init__(self, shape: Union[NoneType, int, Sequence[int]], eps: float = 1e-05, elementwise_affine: bool = True, **kwargs)

Arguments:

  • shape: Input shape. May be left unspecified (e.g. just None) if elementwise_affine=False.
  • eps: Value added to denominator for numerical stability.
  • elementwise_affine: Whether the module has learnable affine parameters.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: A JAX array whose shape is given by shape.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape shape.
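
For example, a minimal usage sketch (the shape (3,) and the input values are just illustrative):

import equinox as eqx
import jax.numpy as jnp

layer = eqx.nn.LayerNorm(shape=(3,))
x = jnp.array([1.0, 2.0, 3.0])
y = layer(x)  # normalised to zero mean and unit variance, then affine-transformed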


equinox.nn.GroupNorm (Module)

Splits the first dimension ("channels") into groups of fixed size. Computes a mean and standard deviation over the contents of each group, and uses these to normalise the group. Optionally applies a channel-wise affine transformation afterwards.

Given an input array \(x\) of shape (channels, ...), this layer splits this up into groups-many arrays \(x_i\) each of shape (channels/groups, ...), and for each one computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

where \(\gamma_i\), \(\beta_i\) have shape (channels/groups,) if channelwise_affine=True, and \(\gamma = 1\), \(\beta = 0\) if channelwise_affine=False.

Cite

Group Normalization

@article{wu2018group,
    author={Yuxin Wu and Kaiming He},
    title={Group Normalization},
    year={2018},
    journal={arXiv:1803.08494},
}
__init__(self, groups: int, channels: Optional[int] = None, eps: float = 1e-05, channelwise_affine: bool = True, **kwargs)

Arguments:

  • groups: The number of groups to split the input into.
  • channels: The number of input channels. May be left unspecified (e.g. just None) if channelwise_affine=False.
  • eps: Value added to denominator for numerical stability.
  • channelwise_affine: Whether the module has learnable affine parameters.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: A JAX array of shape (channels, ...).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (channels, ...).
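
For example, a minimal usage sketch (the group count, channel count and spatial shape are arbitrary choices):

import equinox as eqx
import jax.random as jr

norm = eqx.nn.GroupNorm(groups=2, channels=4)
x = jr.normal(jr.PRNGKey(0), (4, 8, 8))  # shape (channels, ...)
y = norm(x)  # same shape as x; statistics computed within each group of 2 channels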


equinox.experimental.BatchNorm (Module)

Computes a mean and standard deviation over the batch and spatial dimensions of an array, and uses these to normalise the whole array. Optionally applies a channelwise affine transformation afterwards.

Given an input array \(x = [x_1, \ldots, x_C]\) with \(C\) channels, this layer computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

for all \(i\). Here \(*\) denotes elementwise multiplication and \(\gamma\), \(\beta\) have shape \((C,)\) if channelwise_affine=True and \(\gamma = 1\), \(\beta = 0\) if channelwise_affine=False. Expectations are computed over all spatial dimensions and over the batch dimension, and updated batch-by-batch according to momentum.

Example

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
# BatchNorm will automatically update its running statistics internally.

Warning

This layer must be used inside of a vmap or pmap with a matching axis_name. (Not doing so will raise a NameError.)

Warning

equinox.experimental.BatchNorm updates its running statistics as a side effect of its forward pass. Side effects are quite unusual in JAX; as such BatchNorm is considered experimental. Let us know how you find it!

__init__(self, input_size: int, axis_name: Union[Hashable, Sequence[Hashable]], eps: float = 1e-05, channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, **kwargs)

Arguments:

  • input_size: The number of channels in the input array.
  • axis_name: The name of the batch axis to compute statistics over, as passed to axis_name in jax.vmap or jax.pmap. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes.
  • eps: Value added to the denominator for numerical stability.
  • channelwise_affine: Whether the module has learnable channel-wise affine parameters.
  • momentum: The rate at which to update the running statistics. Should be a value between 0 and 1 exclusive.
  • inference: If False then the batch means and variances will be calculated and used to update the running statistics. If True then the running statistics are directly used for normalisation. This may be toggled with equinox.tree_inference or overridden during equinox.experimental.BatchNorm.__call__.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None, inference: Optional[bool] = None) -> Array

Arguments:

  • x: A JAX array of shape (input_size, dim_1, ..., dim_N).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
  • inference: As per equinox.experimental.BatchNorm.__init__. If True or False then it will take priority over self.inference. If None then the value of self.inference will be used.

Returns:

A JAX array of shape (input_size, dim_1, ..., dim_N).

Raises:

A NameError if no vmap (or pmap) is placed around this operation, or if that vmap does not have a matching axis_name.
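
As a sketch of inference-time usage, continuing from the model and x defined in the example above: the stored running statistics can be used by toggling the whole model with equinox.tree_inference (or per call via the inference argument).

inference_model = eqx.tree_inference(model, value=True)
jax.vmap(inference_model, axis_name="batch")(x)  # normalises using the running statistics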


equinox.experimental.SpectralNorm (Module)

Applies spectral normalisation to a given parameter.

Given a weight matrix \(W\), and letting \(\sigma(W)\) denote (an approximation to) its largest singular value, this layer computes \(W / \sigma(W)\).

The approximation \(\sigma(W)\) is computed using power iterations, which are updated (as a side effect) every time \(W / \sigma(W)\) is computed.
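
For intuition, this is essentially classical power iteration. The following is a standalone sketch of the idea for a 2D weight matrix, not SpectralNorm's internal implementation:

import jax.numpy as jnp
import jax.random as jr

def spectral_normalise_sketch(W, num_power_iterations=1, eps=1e-12, key=jr.PRNGKey(0)):
    # Alternately apply W^T and W to a vector, renormalising each time;
    # u^T W v then approximates the largest singular value of W.
    u = jr.normal(key, (W.shape[0],))
    for _ in range(num_power_iterations):
        v = W.T @ u
        v = v / (jnp.linalg.norm(v) + eps)
        u = W @ v
        u = u / (jnp.linalg.norm(u) + eps)
    sigma = u @ W @ v
    return W / sigma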

Spectral normalisation is particularly commonly used when training generative adversarial networks; see Spectral Normalization for Generative Adversarial Networks for more details and motivation.

equinox.experimental.SpectralNorm should be used to replace an individual parameter, not (as in some other libraries) the layer containing that parameter.

Example

To add spectral normalisation during model creation:

import equinox as eqx
import equinox.experimental as eqxe
import jax.random as jr

key = jr.PRNGKey(0)
linear = eqx.nn.Linear(2, 2, key=key)
sn_weight = eqxe.SpectralNorm(linear.weight, key=key)
linear = eqx.tree_at(lambda l: l.weight, linear, sn_weight)

Example

Alternatively, iterate over the model to add spectral normalisation after model creation:

import equinox as eqx
import equinox.experimental as eqxe
import jax
import jax.random as jr
import jax.tree_util as jtu
import functools as ft

key = jr.PRNGKey(0)
model_key, spectral_key = jr.split(key)
SN = ft.partial(eqxe.SpectralNorm, key=spectral_key)

def _is_linear(leaf):
    return isinstance(leaf, eqx.nn.Linear)

def _apply_sn_to_linear(module):
    if _is_linear(module):
        module = eqx.tree_at(lambda m: m.weight, module, replace_fn=SN)
    return module

def apply_sn(model):
    return jtu.tree_map(_apply_sn_to_linear, model, is_leaf=_is_linear)

model = eqx.nn.MLP(2, 2, 2, 2, key=model_key)
model_with_sn = apply_sn(model)

Example

Switching the model to inference mode after training:

import equinox as eqx
import equinox.experimental as eqxe
import jax.tree_util as jtu

def _is_sn(leaf):
    return isinstance(leaf, eqxe.SpectralNorm)

def _set_inference_on_sn(module):
    if _is_sn(module):
        module = eqx.tree_at(lambda m: m.inference, module, True)
    return module

def set_inference(model):
    return jtu.tree_map(_set_inference_on_sn, model, is_leaf=_is_sn)

model = ...  # set up model, train it, etc.
model = set_inference(model)

Warning

equinox.experimental.SpectralNorm updates its running statistics as a side effect of its forward pass. Side effects are quite unusual in JAX; as such SpectralNorm is considered experimental. Let us know how you find it!

__init__(self, weight: Array, num_power_iterations: int = 1, eps: float = 1e-12, inference: bool = False, *, key: jax.random.PRNGKey, **kwargs)

Arguments:

  • weight: The parameter (a JAX array) to apply spectral normalisation to.
  • num_power_iterations: The number of power iterations to apply every time the array is accessed.
  • eps: Epsilon for numerical stability when calculating norms.
  • inference: Whether this is in inference mode, at which time no power iterations are performed. This may be toggled with equinox.tree_inference.
  • key: A jax.random.PRNGKey used to provide randomness for initialisation. (Keyword only argument.)
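
For example, a sketch of passing these arguments explicitly, reusing the linear layer and key from the first example above (the value num_power_iterations=5 is an arbitrary choice):

sn_weight = eqxe.SpectralNorm(linear.weight, num_power_iterations=5, key=key)
linear = eqx.tree_at(lambda l: l.weight, linear, sn_weight)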