Attention
equinox.nn.MultiheadAttention (Module)
Computes

\[\text{MultiheadAttention}(Q, K, V) = \sum_{i=1}^{h} \text{Attention}\left(QW^Q_i, KW^K_i, VW^V_i\right)W^O_i\]

where:

- The inputs are \(Q \in \mathbb{R}^{d_\text{seq} \times d_\text{query}}\), \(K \in \mathbb{R}^{d_\text{seq} \times d_\text{key}}\), \(V \in \mathbb{R}^{d_\text{seq} \times d_\text{value}}\). These are referred to as query, key, and value respectively. Meanwhile \(d_\text{seq}\) is the sequence length, and \(d_\text{query}\), \(d_\text{key}\), \(d_\text{value}\) are numbers of channels.
- The trainable weights are \(W^Q_i \in \mathbb{R}^{d_\text{query} \times d_\text{qk}}\), \(W^K_i \in \mathbb{R}^{d_\text{key} \times d_\text{qk}}\), \(W^V_i \in \mathbb{R}^{d_\text{value} \times d_\text{vo}}\), \(W^O_i \in \mathbb{R}^{d_\text{vo} \times d_\text{output}}\), with \(i \in \{1, \ldots, h\}\), where \(h\) is the number of heads, and \(d_\text{qk}\), \(d_\text{vo}\), \(d_\text{output}\) are hyperparameters.
- \(\text{Attention}\) is defined as \(\text{Attention}(\widetilde{Q}, \widetilde{K}, \widetilde{V}) = \text{softmax}\left(\frac{\widetilde{Q}\widetilde{K}^\intercal}{\sqrt{d_\text{qk}}}\right)\widetilde{V}\).
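For reference, here is a minimal JAX sketch of the per-head \(\text{Attention}\) operation defined above (the function name and shapes are illustrative only, not part of the Equinox API):

```python
import jax
import jax.numpy as jnp

def attention(q_tilde, k_tilde, v_tilde):
    # q_tilde: (d_seq, d_qk), k_tilde: (d_seq, d_qk), v_tilde: (d_seq, d_vo)
    d_qk = q_tilde.shape[-1]
    logits = q_tilde @ k_tilde.T / jnp.sqrt(d_qk)   # (d_seq, d_seq)
    weights = jax.nn.softmax(logits, axis=-1)       # each row sums to one
    return weights @ v_tilde                        # (d_seq, d_vo)
```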
Cite
@inproceedings{vaswani2017attention,
author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and
Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and
Kaiser, {\L}ukasz and Polosukhin, Illia},
booktitle={Advances in Neural Information Processing Systems},
publisher={Curran Associates, Inc.},
title={Attention is All You Need},
volume={30},
year={2017}
}
FAQ
Different software libraries implement multihead attention in slightly different ways. Some add bias terms by default and some do not. Most fix the values of \(d_\text{qk}\), \(d_\text{vo}\), \(d_\text{output}\) in terms of \(d_\text{query}\), \(d_\text{key}\), or \(d_\text{value}\). Equinox chooses to expose all of these as options.
Relative to the original Attention is All You Need paper: our \(d_\text{qk}\) is their "\(d_k\)". Our \(d_\text{vo}\) is their "\(d_\text{v}\)". They fix \(d_\text{query} = d_\text{key} = d_\text{value} = d_\text{output}\) and refer to it as "\(d_\text{model}\)".
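For instance, to roughly mirror the paper's setup with \(d_\text{model} = 512\) and \(h = 8\), the following sketch relies only on the defaults described below (with these arguments, \(d_\text{qk}\) and \(d_\text{vo}\) default to \(512 // 8 = 64\)):

```python
import equinox as eqx
import jax.random as jr

# query_size = key_size = value_size = output_size = 512 ("d_model" in the paper),
# while qk_size and vo_size default to 512 // 8 = 64 ("d_k" and "d_v").
mha = eqx.nn.MultiheadAttention(num_heads=8, query_size=512, key=jr.PRNGKey(0))
```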
__init__(self, num_heads: int, query_size: int, key_size: Optional[int] = None, value_size: Optional[int] = None, output_size: Optional[int] = None, qk_size: Optional[int] = None, vo_size: Optional[int] = None, use_query_bias: bool = False, use_key_bias: bool = False, use_value_bias: bool = False, use_output_bias: bool = False, dropout_p: float = 0.0, inference: bool = False, dtype = None, *, key: PRNGKeyArray)
Arguments:
- `num_heads`: Number of parallel attention heads \(h\).
- `query_size`: Number of input channels for query \(Q\).
- `key_size`: Number of input channels for key \(K\). Defaults to `query_size`.
- `value_size`: Number of input channels for value \(V\). Defaults to `query_size`.
- `output_size`: Number of output channels. Defaults to `query_size`.
- `qk_size`: Number of channels to compare query and key over, per head. Defaults to `query_size // num_heads`.
- `vo_size`: Number of channels to compare attention-weighted value and output over, per head. Defaults to `query_size // num_heads`.
- `use_query_bias`: Whether to use a bias term in the query projections.
- `use_key_bias`: Whether to use a bias term in the key projections.
- `use_value_bias`: Whether to use a bias term in the value projections.
- `use_output_bias`: Whether to use a bias term in the output projection.
- `dropout_p`: Dropout probability on attention weights.
- `inference`: Whether to actually apply dropout at all. If `True` then dropout is not applied. If `False` then dropout is applied. This may be toggled with `equinox.nn.inference_mode` or overridden during `equinox.nn.MultiheadAttention.__call__`.
- `dtype`: The dtype to use for all trainable parameters in this layer. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.)
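As a construction sketch using some of the optional arguments above (all of the values here are arbitrary):

```python
import equinox as eqx
import jax.random as jr

# Cross-attention layer: keys/values carry 32 channels, queries/outputs 64,
# with biases on the query and output projections and dropout on the weights.
mha = eqx.nn.MultiheadAttention(
    num_heads=4,
    query_size=64,
    key_size=32,
    value_size=32,
    output_size=64,
    use_query_bias=True,
    use_output_bias=True,
    dropout_p=0.1,
    key=jr.PRNGKey(0),
)
```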
__call__(self, query: Array, key_: Array, value: Array, mask: Optional[Array] = None, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None, deterministic: Optional[bool] = None, process_heads: Optional[Callable[[Array, Array, Array], tuple[Array, Array, Array]]] = None) -> Array
Arguments:
- `query`: Query embedding. Should be a JAX array of shape `(query_seq_length, query_size)`.
- `key_`: Key embedding. Should be a JAX array of shape `(kv_seq_length, key_size)`.
- `value`: Value embedding. Should be a JAX array of shape `(kv_seq_length, value_size)`.
- `mask`: Optional mask preventing attention to certain positions. Should either be a JAX array of shape `(query_seq_length, kv_seq_length)`, or (for custom per-head masking) `(num_heads, query_seq_length, kv_seq_length)`. A value of `False` at a position indicates that position should be ignored.
- `key`: A `jax.random.PRNGKey` used for dropout. Unused if `dropout_p = 0`. (Keyword only argument.)
- `inference`: As `equinox.nn.Dropout.__call__`. (Keyword only argument.)
- `deterministic`: (Deprecated in favour of `inference`.)
- `process_heads`: A function that takes in the query, key, and value heads and returns new query, key, and value heads. For example, this can be used to implement relative positional embeddings; see e.g. `RotaryPositionalEmbedding` below for an example. (Keyword only argument.)
Returns:
A JAX array of shape `(query_seq_length, output_size)`.
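A call sketch for self-attention with a causal mask (all shapes here are arbitrary; Equinox modules act on unbatched inputs, so a batch dimension would be handled with `jax.vmap`):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

mha = eqx.nn.MultiheadAttention(num_heads=4, query_size=64, key=jr.PRNGKey(0))
x = jnp.zeros((10, 64))                                 # (seq_length, query_size)
causal_mask = jnp.tril(jnp.ones((10, 10), dtype=bool))  # attend to past positions only
out = mha(x, x, x, mask=causal_mask)                    # shape (10, 64)
```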
equinox.nn.RotaryPositionalEmbedding (Module)
A rotary positional encoding module, as described in the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding". While this module can be used in any context, it is particularly useful for providing positional information to transformer models.
Example
The following example demonstrates how to use `RotaryPositionalEmbedding` in a simple transformer model.
import equinox as eqx
import jax
from equinox.nn import MultiheadAttention, RotaryPositionalEmbedding
from jaxtyping import Array, Float

class TransformerBlock(eqx.Module):
    rope_embeddings: RotaryPositionalEmbedding
    mha_attention: MultiheadAttention

    def __init__(...):
        self.rope_embeddings = RotaryPositionalEmbedding(...)
        self.mha_attention = MultiheadAttention(...)

    def __call__(...):
        def process_heads(
            query_heads: Float[Array, "seq_length num_heads qk_size"],
            key_heads: Float[Array, "seq_length num_heads qk_size"],
            value_heads: Float[Array, "seq_length num_heads vo_size"]
        ) -> tuple[
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads vo_size"]
        ]:
            # Apply the rotary embedding to the query and key heads, vmapping
            # over the head axis so each head is rotated independently.
            # The value heads are passed through unchanged.
            query_heads = jax.vmap(self.rope_embeddings,
                                   in_axes=1,
                                   out_axes=1)(query_heads)
            key_heads = jax.vmap(self.rope_embeddings,
                                 in_axes=1,
                                 out_axes=1)(key_heads)
            return query_heads, key_heads, value_heads

        x = self.mha_attention(..., process_heads=process_heads)
        ...
Cite
RoFormer: Enhanced Transformer with Rotary Position Embedding
@misc{su2023roformer,
title={RoFormer: Enhanced Transformer with Rotary Position Embedding},
author={Jianlin Su and Yu Lu and Shengfeng Pan and Ahmed Murtadha and
Bo Wen and Yunfeng Liu},
year={2023},
    eprint={2104.09864},
    archivePrefix={arXiv}
}
__init__(self, embedding_size: int, theta: float = 10000.0, dtype: Any = <factory>)
Arguments:
- `embedding_size`: Size of each embedding vector. Must be non-negative and even.
- `theta`: The base frequency for the sinusoidal functions used in positional encoding. Specifies how quickly the inner product will decay with relative distance between tokens. Larger values of theta will result in slower oscillations. Default is 10_000, as per the original paper.
- `dtype`: The dtype to use for the precomputed frequencies. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
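As a sketch of how `theta` sets the rotation frequencies in the RoFormer scheme (this follows the paper's definition; it is not necessarily how Equinox computes its precomputed frequencies internally, and `rope_frequencies` is a hypothetical helper):

```python
import jax.numpy as jnp

def rope_frequencies(embedding_size: int, theta: float = 10_000.0):
    # One frequency per pair of channels: theta ** (-2i / embedding_size).
    # Larger theta gives smaller frequencies, i.e. slower oscillation with position.
    i = jnp.arange(embedding_size // 2)
    return theta ** (-2.0 * i / embedding_size)
```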
__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array
Arguments:
- `x`: A JAX array of shape `(seq_length, embedding_size)`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape `(seq_length, embedding_size)`, with the rotary positional encoding applied to the input.
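A minimal usage sketch (the embedding size and sequence length here are arbitrary):

```python
import equinox as eqx
import jax.numpy as jnp

rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=64)
tokens = jnp.zeros((10, 64))   # (seq_length, embedding_size)
rotated = rope(tokens)         # shape (10, 64), with the rotary encoding applied
```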