Normalisation
equinox.nn.LayerNorm (Module)
Computes a mean and standard deviation over the whole input array, and uses these to normalise the whole array. Optionally applies an elementwise affine transformation afterwards.
Given an input array \(x\), this layer computes

\[\frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta\]

where \(\gamma\), \(\beta\) have the same shape as \(x\) if `elementwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `elementwise_affine=False`.
Cite

```bibtex
@article{ba2016layer,
    author={Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton},
    title={Layer Normalization},
    year={2016},
    journal={arXiv:1607.06450},
}
```
FAQ

If you need to normalise over only some input dimensions, then this can be achieved by vmap'ing. For example, the following will compute statistics over every dimension except the first:

```python
layer = LayerNorm(...)
array = jax.vmap(layer)(array)
```
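Concretely, such a vmap'd LayerNorm might look like the following runnable sketch (the shapes here are arbitrary, chosen for illustration):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Statistics are computed over the trailing (8,)-shaped axis only; the
# leading axis of size 4 is mapped over by vmap.
layer = eqx.nn.LayerNorm(shape=(8,))
array = jnp.arange(32.0).reshape(4, 8)
out = jax.vmap(layer)(array)  # shape (4, 8)
```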
__init__(self, shape: Union[NoneType, int, Sequence[int]], eps: float = 1e-05, elementwise_affine: bool = True, **kwargs)

Arguments:

- `shape`: Input shape. May be left unspecified (e.g. just `None`) if `elementwise_affine=False`.
- `eps`: Value added to denominator for numerical stability.
- `elementwise_affine`: Whether the module has learnable affine parameters.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

- `x`: A JAX array whose shape is given by `shape`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape `shape`.
equinox.nn.GroupNorm (Module)
Splits the first dimension ("channels") into groups of fixed size. Computes a mean and standard deviation over the contents of each group, and uses these to normalise the group. Optionally applies a channel-wise affine transformation afterwards.
Given an input array \(x\) of shape `(channels, ...)`, this layer splits this up into `groups`-many arrays \(x_i\) each of shape `(channels/groups, ...)`, and for each one computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\mathrm{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

where \(\gamma_i\), \(\beta_i\) have shape `(channels/groups,)` if `channelwise_affine=True`, and \(\gamma_i = 1\), \(\beta_i = 0\) if `channelwise_affine=False`.
Cite

```bibtex
@article{wu2018group,
    author={Yuxin Wu and Kaiming He},
    title={Group Normalization},
    year={2018},
    journal={arXiv:1803.08494},
}
```
__init__(self, groups: int, channels: Optional[int] = None, eps: float = 1e-05, channelwise_affine: bool = True, **kwargs)

Arguments:

- `groups`: The number of groups to split the input into.
- `channels`: The number of input channels. May be left unspecified (e.g. just `None`) if `channelwise_affine=False`.
- `eps`: Value added to denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable affine parameters.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

- `x`: A JAX array of shape `(channels, ...)`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape `(channels, ...)`.
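For illustration, a minimal usage sketch (sizes chosen arbitrarily):

```python
import equinox as eqx
import jax.numpy as jnp

# 6 channels split into 3 groups of 2 channels each; mean and variance are
# computed separately within each group.
norm = eqx.nn.GroupNorm(groups=3, channels=6)
x = jnp.arange(96.0).reshape(6, 4, 4)  # shape (channels, ...)
y = norm(x)                            # shape (6, 4, 4)
```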
equinox.experimental.BatchNorm (Module)
Computes a mean and standard deviation over the batch and spatial dimensions of an array, and uses these to normalise the whole array. Optionally applies a channelwise affine transformation afterwards.
Given an input array \(x = [x_1, ... x_C]\) with \(C\) channels, this layer computes

\[\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\mathrm{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i\]

for all \(i\). Here \(*\) denotes elementwise multiplication and \(\gamma\), \(\beta\) have shape \((C,)\) if `channelwise_affine=True`, and \(\gamma = 1\), \(\beta = 0\) if `channelwise_affine=False`. Expectations are computed over all spatial dimensions and over the batch dimension, and updated batch-by-batch according to `momentum`.
Example

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
# BatchNorm will automatically update its running statistics internally.
```
Warning

This layer must be used inside of a `vmap` or `pmap` with a matching `axis_name`. (Not doing so will raise a `NameError`.)
Warning

`equinox.experimental.BatchNorm` updates its running statistics as a side effect of its forward pass. Side effects are quite unusual in JAX; as such `BatchNorm` is considered experimental. Let us know how you find it!
__init__(self, input_size: int, axis_name: Union[Hashable, Sequence[Hashable]], eps: float = 1e-05, channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, **kwargs)

Arguments:

- `input_size`: The number of channels in the input array.
- `axis_name`: The name of the batch axis to compute statistics over, as passed to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine parameters.
- `momentum`: The rate at which to update the running statistics. Should be a value between 0 and 1 exclusive.
- `inference`: If `False` then the batch means and variances will be calculated and used to update the running statistics. If `True` then the running statistics are directly used for normalisation. This may be toggled with `equinox.tree_inference` or overridden during `equinox.experimental.BatchNorm.__call__`.
__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None, inference: Optional[bool] = None) -> Array

Arguments:

- `x`: A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
- `inference`: As per `equinox.experimental.BatchNorm.__init__`. If `True` or `False` then it will take priority over `self.inference`. If `None` then the value from `self.inference` will be used.

Returns:

A JAX array of shape `(input_size, dim_1, ..., dim_N)`.

Raises:

A `NameError` if no `vmap`s are placed around this operation, or if this `vmap` does not have a matching `axis_name`.
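After training, the running statistics can be used directly by switching to inference mode. A sketch, reusing `model` and `x` from the example above (`equinox.tree_inference` is the toggle documented in `__init__`):

```python
# Freeze the running statistics: BatchNorm now normalises with them rather
# than with per-batch statistics, and no longer updates them.
inference_model = eqx.tree_inference(model, value=True)
out = jax.vmap(inference_model, axis_name="batch")(x)
```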
equinox.experimental.SpectralNorm (Module)
Applies spectral normalisation to a given parameter.
Given a weight matrix \(W\), and letting \(\sigma(W)\) denote (an approximation to) its largest singular value, this layer computes \(W / \sigma(W)\).
The approximation \(\sigma(W)\) is computed using power iterations, which are updated (as a side-effect) every time \(W / \sigma(W)\) is computed.
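For intuition, here is a self-contained sketch of a power-iteration estimate of \(\sigma(W)\); this is illustrative only, not Equinox's internal implementation:

```python
import jax.numpy as jnp
import jax.random as jr

def estimate_sigma(W, key, num_power_iterations=1, eps=1e-12):
    # Alternately apply W and W^T to a random vector, normalising each time;
    # this converges towards the leading singular vectors of W.
    u = jr.normal(key, (W.shape[0],))
    for _ in range(num_power_iterations):
        v = W.T @ u
        v = v / (jnp.linalg.norm(v) + eps)
        u = W @ v
        u = u / (jnp.linalg.norm(u) + eps)
    # The bilinear form u^T W v then approximates the largest singular value.
    return u @ W @ v

W = jr.normal(jr.PRNGKey(0), (4, 8))
sigma = estimate_sigma(W, jr.PRNGKey(1), num_power_iterations=100)
W_normalised = W / sigma  # largest singular value of W_normalised is ~1
```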
Spectral normalisation is commonly used when training generative adversarial networks; see [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) for more details and motivation.
`equinox.experimental.SpectralNorm` should be used to replace an individual parameter (unlike some libraries, which wrap the whole layer containing that parameter).
Example

To add spectral normalisation during model creation:

```python
import equinox as eqx
import equinox.experimental as eqxe
import jax.random as jr

key = jr.PRNGKey(0)
linear = eqx.nn.Linear(2, 2, key=key)
sn_weight = eqxe.SpectralNorm(linear.weight, key=key)
linear = eqx.tree_at(lambda l: l.weight, linear, sn_weight)
```
Example

Alternatively, iterate over the model to add spectral normalisation after model creation:

```python
import equinox as eqx
import equinox.experimental as eqxe
import jax
import jax.random as jr
import jax.tree_util as jtu
import functools as ft

key = jr.PRNGKey(0)
model_key, spectral_key = jr.split(key)
SN = ft.partial(eqxe.SpectralNorm, key=spectral_key)

def _is_linear(leaf):
    return isinstance(leaf, eqx.nn.Linear)

def _apply_sn_to_linear(module):
    if _is_linear(module):
        module = eqx.tree_at(lambda m: m.weight, module, replace_fn=SN)
    return module

def apply_sn(model):
    return jtu.tree_map(_apply_sn_to_linear, model, is_leaf=_is_linear)

model = eqx.nn.MLP(2, 2, 2, 2, key=model_key)
model_with_sn = apply_sn(model)
```
Example

Switching the model to inference mode after training:

```python
import equinox as eqx
import equinox.experimental as eqxe
import jax.tree_util as jtu

def _is_sn(leaf):
    return isinstance(leaf, eqxe.SpectralNorm)

def _set_inference_on_sn(module):
    if _is_sn(module):
        module = eqx.tree_at(lambda m: m.inference, module, True)
    return module

def set_inference(model):
    return jtu.tree_map(_set_inference_on_sn, model, is_leaf=_is_sn)

model = ...  # set up model, train it, etc.
model = set_inference(model)
```
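Since `inference` is documented as toggleable with `equinox.tree_inference` (see `__init__` below), the same switch can also be written more compactly; a sketch, assuming a trained `model`:

```python
import equinox as eqx

model = ...  # set up model, train it, etc.
# Sets `inference=True` on every submodule that has such a flag,
# including every SpectralNorm.
model = eqx.tree_inference(model, value=True)
```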
Warning

`equinox.experimental.SpectralNorm` updates its running statistics as a side effect of its forward pass. Side effects are quite unusual in JAX; as such `SpectralNorm` is considered experimental. Let us know how you find it!
__init__(self, weight: Array, num_power_iterations: int = 1, eps: float = 1e-12, inference: bool = False, *, key: jax.random.PRNGKey, **kwargs)

Arguments:

- `weight`: The parameter (a JAX array) to apply spectral normalisation to.
- `num_power_iterations`: The number of power iterations to apply every time the array is accessed.
- `eps`: Epsilon for numerical stability when calculating norms.
- `inference`: Whether this is in inference mode, at which time no power iterations are performed. This may be toggled with `equinox.tree_inference`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation. (Keyword only argument.)