
Attention

equinox.nn.MultiheadAttention (Module)

Computes

\[\text{MultiheadAttention}(Q, K, V) = \sum_i \text{Attention}\left(QW^Q_i, KW^K_i, VW^V_i\right)W^O_i\]

where:

  • The inputs are \(Q \in \mathbb{R}^{d_\text{seq} \times d_\text{query}}\), \(K \in \mathbb{R}^{d_\text{seq} \times d_\text{key}}\), \(V \in \mathbb{R}^{d_\text{seq} \times d_\text{value}}\). These are referred to as query, key, and value respectively. Meanwhile \(d_\text{seq}\) is the sequence length, and \(d_\text{query}\), \(d_\text{key}\), \(d_\text{value}\) are the numbers of channels. (A single \(d_\text{seq}\) is used here for simplicity; in general \(K\) and \(V\) must share a sequence length, which may differ from that of \(Q\).)

  • The trainable weights are \(W^Q_i \in \mathbb{R}^{d_\text{query} \times d_\text{qk}}\), \(W^K_i \in \mathbb{R}^{d_\text{key} \times d_\text{qk}}\), \(W^V_i \in \mathbb{R}^{d_\text{value} \times d_\text{vo}}\), \(W^O_i \in \mathbb{R}^{d_\text{vo} \times d_\text{output}}\), with \(i \in \{1, \ldots, h\}\), where \(h\) is the number of heads, and \(d_\text{qk}\), \(d_\text{vo}\), \(d_\text{output}\) are hyperparameters.

  • \(\text{Attention}\) is defined as \(\text{Attention}(\widetilde{Q}, \widetilde{K}, \widetilde{V}) = \text{softmax}(\frac{\widetilde{Q}\widetilde{K}^\intercal} {\sqrt{d_\text{qk}}})\widetilde{V}\), where the softmax is taken row-wise, that is, over the key positions.
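The formula can be checked numerically with a minimal sketch in plain jax.numpy. It assumes the per-head weights are stored as stacked arrays (`WQ[i]` plays the role of \(W^Q_i\), and so on); the function names are hypothetical and this is not Equinox's own implementation. It also shows why the sum over heads equals the more familiar "concatenate the heads, then apply one output projection" form.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def attention(q, k, v):
    # q: (q_seq, d_qk), k: (kv_seq, d_qk), v: (kv_seq, d_vo)
    logits = q @ k.T / jnp.sqrt(q.shape[-1])   # (q_seq, kv_seq)
    weights = jax.nn.softmax(logits, axis=-1)  # normalise over key positions
    return weights @ v                         # (q_seq, d_vo)


def mha_sum_of_heads(Q, K, V, WQ, WK, WV, WO):
    # Hypothetical stacked weights: WQ (h, d_query, d_qk), WK (h, d_key, d_qk),
    # WV (h, d_value, d_vo), WO (h, d_vo, d_output).
    out = 0.0
    for i in range(WQ.shape[0]):
        out = out + attention(Q @ WQ[i], K @ WK[i], V @ WV[i]) @ WO[i]
    return out  # (q_seq, d_output)


def mha_concat_heads(Q, K, V, WQ, WK, WV, WO):
    # Equivalent form: concatenate the heads, then apply one stacked output matrix.
    heads = [attention(Q @ WQ[i], K @ WK[i], V @ WV[i]) for i in range(WQ.shape[0])]
    concat = jnp.concatenate(heads, axis=-1)  # (q_seq, h * d_vo)
    # Row i * d_vo + j of the reshaped matrix is WO[i, j], matching concat's columns.
    return concat @ WO.reshape(-1, WO.shape[-1])


h, d_query, d_key, d_value, d_qk, d_vo, d_output = 2, 6, 5, 4, 3, 2, 7
ks = jr.split(jr.PRNGKey(0), 7)
Q = jr.normal(ks[0], (4, d_query))  # 4 query positions
K = jr.normal(ks[1], (3, d_key))    # 3 key/value positions
V = jr.normal(ks[2], (3, d_value))
WQ = jr.normal(ks[3], (h, d_query, d_qk))
WK = jr.normal(ks[4], (h, d_key, d_qk))
WV = jr.normal(ks[5], (h, d_value, d_vo))
WO = jr.normal(ks[6], (h, d_vo, d_output))

out_sum = mha_sum_of_heads(Q, K, V, WQ, WK, WV, WO)
out_cat = mha_concat_heads(Q, K, V, WQ, WK, WV, WO)
```

Both functions return an array of shape \((4, d_\text{output})\), and they agree up to floating-point error.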

Cite

Attention is All You Need

@inproceedings{vaswani2017attention,
    author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and
            Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and
            Kaiser, {\L}ukasz and Polosukhin, Illia},
    booktitle={Advances in Neural Information Processing Systems},
    publisher={Curran Associates, Inc.},
    title={Attention is All You Need},
    volume={30},
    year={2017}
}

FAQ

Different software libraries often implement multihead attention in slightly different ways. Some add biases by default and some do not. Most fix the values of \(d_\text{qk}, d_\text{vo}, d_\text{output}\) in terms of \(d_\text{query}\), \(d_\text{key}\) or \(d_\text{value}\). Equinox chooses to expose all of these as options.

Relative to the original Attention is All You Need paper: our \(d_\text{qk}\) is their "\(d_k\)". Our \(d_\text{vo}\) is their "\(d_\text{v}\)". They fix \(d_\text{query} = d_\text{key} = d_\text{value} = d_\text{output}\) and refer to it as "\(d_\text{model}\)".

__init__(self, num_heads: int, query_size: int, key_size: Optional[int] = None, value_size: Optional[int] = None, output_size: Optional[int] = None, qk_size: Optional[int] = None, vo_size: Optional[int] = None, use_query_bias: bool = False, use_key_bias: bool = False, use_value_bias: bool = False, use_output_bias: bool = False, dropout_p: float = 0.0, inference: bool = False, dtype = None, *, key: PRNGKeyArray)

Arguments:

  • num_heads: Number of parallel attention heads \(h\).
  • query_size: Number of input channels for query \(Q\).
  • key_size: Number of input channels for key \(K\). Defaults to query_size.
  • value_size: Number of input channels for value \(V\). Defaults to query_size.
  • output_size: Number of output channels. Defaults to query_size.
  • qk_size: Number of channels to compare query and key over, per head. Defaults to query_size // num_heads.
  • vo_size: Number of channels to compare attention-weighted value and output over, per head. Defaults to query_size // num_heads.
  • use_query_bias: Whether to use a bias term in the query projections.
  • use_key_bias: Whether to use a bias term in the key projections.
  • use_value_bias: Whether to use a bias term in the value projections.
  • use_output_bias: Whether to use a bias term in the output projection.
  • dropout_p: Dropout probability on attention weights.
  • inference: Whether to actually apply dropout at all. If True then dropout is not applied. If False then dropout is applied. This may be toggled with equinox.nn.inference_mode or overridden during equinox.nn.MultiheadAttention.__call__.
  • dtype: The dtype to use for all trainable parameters in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

__call__(self, query: Array, key_: Array, value: Array, mask: Optional[Array] = None, *, key: Optional[PRNGKeyArray] = None, inference: Optional[bool] = None, deterministic: Optional[bool] = None, process_heads: Optional[Callable[[Array, Array, Array], tuple[Array, Array, Array]]] = None) -> Array

Arguments:

  • query: Query embedding. Should be a JAX array of shape (query_seq_length, query_size).
  • key_: Key embedding. Should be a JAX array of shape (kv_seq_length, key_size).
  • value: Value embedding. Should be a JAX array of shape (kv_seq_length, value_size).
  • mask: Optional boolean mask preventing attention to certain positions. Should be a JAX array of shape either (query_seq_length, kv_seq_length) or, for custom per-head masking, (num_heads, query_seq_length, kv_seq_length). A value of False at a position indicates that position should be ignored.
  • key: A jax.random.PRNGKey used for dropout. Unused if dropout_p = 0. (Keyword only argument.)
  • inference: As equinox.nn.Dropout.__call__. (Keyword only argument.)
  • deterministic: (Deprecated in favour of inference.)
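  • process_heads: A function that takes in the query, key, and value heads and returns new query, key, and value heads. The query heads have shape (query_seq_length, num_heads, qk_size), the key heads (kv_seq_length, num_heads, qk_size), and the value heads (kv_seq_length, num_heads, vo_size). This can be used, for example, to implement relative positional embeddings; see equinox.nn.RotaryPositionalEmbedding. (Keyword only argument.)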

Returns:

A JAX array of shape (query_seq_length, output_size).


equinox.nn.RotaryPositionalEmbedding (Module)

A rotary positional encoding module, as described in the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding". While this module can be used in any context, it is particularly useful for providing positional information to transformer models.
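The rotation itself can be sketched in plain jax.numpy. This is a hypothetical reference function (`rope_reference` is not part of Equinox) that assumes the "rotate-half" pairing, where channel j is rotated together with channel j + d/2; the RoFormer paper pairs adjacent channels instead, and Equinox's internal layout may differ. Either way, each pair is rotated by an angle proportional to its position, which is why the dot product of a rotated query and a rotated key depends only on their relative offset.

```python
import jax.numpy as jnp
import jax.random as jr


def rope_reference(x, theta=10_000.0):
    # Hypothetical sketch: rotate pairs (x[:, j], x[:, j + d/2]) by position * freq_j.
    seq_length, d = x.shape
    half = d // 2
    # freq_j = theta^(-2j/d) for j = 0, ..., d/2 - 1, as in the RoFormer paper.
    freqs = theta ** (-2.0 * jnp.arange(half) / d)
    angles = jnp.arange(seq_length)[:, None] * freqs[None, :]  # (seq_length, d/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Standard 2D rotation of each (x1, x2) pair.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


seq_length, d = 8, 16
q = jr.normal(jr.PRNGKey(0), (d,))
k = jr.normal(jr.PRNGKey(1), (d,))
# Place the same query (and key) vector at every position, then rotate by position.
rq = rope_reference(jnp.tile(q, (seq_length, 1)))
rk = rope_reference(jnp.tile(k, (seq_length, 1)))

# Both pairs of positions have offset 2, so the scores match.
score_3_1 = rq[3] @ rk[1]
score_6_4 = rq[6] @ rk[4]
```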

Example

The following example demonstrates how to use RotaryPositionalEmbedding in a simple transformer model.

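```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray
from equinox.nn import MultiheadAttention, RotaryPositionalEmbedding


class TransformerBlock(eqx.Module):
    mha_attention: MultiheadAttention
    rope_embeddings: RotaryPositionalEmbedding

    # Illustrative constructor arguments; a full block would take more.
    def __init__(self, embed_size: int, num_heads: int, *, key: PRNGKeyArray):
        self.mha_attention = MultiheadAttention(num_heads, embed_size, key=key)
        # RoPE acts on each head separately, so its size is the per-head qk_size.
        self.rope_embeddings = RotaryPositionalEmbedding(self.mha_attention.qk_size)

    def __call__(
        self, x: Float[Array, "seq_length embed_size"]
    ) -> Float[Array, "seq_length embed_size"]:
        def process_heads(
            query_heads: Float[Array, "seq_length num_heads qk_size"],
            key_heads: Float[Array, "seq_length num_heads qk_size"],
            value_heads: Float[Array, "seq_length num_heads vo_size"],
        ) -> tuple[
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads vo_size"],
        ]:
            # Map over the heads axis (axis 1): each call sees (seq_length, qk_size).
            query_heads = jax.vmap(self.rope_embeddings, in_axes=1, out_axes=1)(query_heads)
            key_heads = jax.vmap(self.rope_embeddings, in_axes=1, out_axes=1)(key_heads)
            # Values are left unrotated.
            return query_heads, key_heads, value_heads

        # Self-attention: query, key and value are all `x`.
        x = self.mha_attention(x, x, x, process_heads=process_heads)
        # ... the rest of the block (residual connection, MLP, etc.) would go here.
        return x


block = TransformerBlock(embed_size=64, num_heads=4, key=jr.PRNGKey(0))
y = block(jnp.ones((10, 64)))  # shape (10, 64)
```
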
Cite

RoFormer: Enhanced Transformer with Rotary Position Embedding

@misc{su2023roformer,
  title={RoFormer: Enhanced Transformer with Rotary Position Embedding},
  author={Jianlin Su and Yu Lu and Shengfeng Pan and Ahmed Murtadha and
      Bo Wen and Yunfeng Liu},
  year={2023},
  eprint={arXiv:2104.09864},
}

__init__(self, embedding_size: int, theta: float = 10000.0, dtype: Any = <factory>)

Arguments:

  • embedding_size: Size of each embedding vector. Must be non-negative and even.
  • theta: The base frequency for the sinusoidal functions used in the positional encoding. It controls how quickly the inner product decays with the relative distance between tokens; larger values of theta give slower oscillations. Defaults to 10_000, as in the original paper.
  • dtype: The dtype to use for the precomputed frequencies. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
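To make the role of theta concrete, here is a sketch of the per-pair frequencies \(\theta^{-2j/d}\), \(j = 0, \ldots, d/2 - 1\), from the RoFormer paper. The helper name `rope_frequencies` and the larger theta value are illustrative, not part of Equinox. A larger theta lowers the slowest frequency, so the corresponding pair of channels needs more tokens to complete one full rotation.

```python
import jax.numpy as jnp


def rope_frequencies(embedding_size, theta):
    # Hypothetical helper: freq_j = theta^(-2j / embedding_size), j = 0, ..., embedding_size/2 - 1.
    j = jnp.arange(embedding_size // 2)
    return theta ** (-2.0 * j / embedding_size)


f_default = rope_frequencies(64, 10_000.0)
f_large = rope_frequencies(64, 500_000.0)  # illustrative larger theta

# Wavelength, in tokens, of the slowest-rotating pair: 2 * pi / freq.
longest_default = 2 * jnp.pi / f_default[-1]
longest_large = 2 * jnp.pi / f_large[-1]
```

The fastest pair always rotates by one radian per token (frequency \(\theta^0 = 1\)); only the slower pairs depend on theta.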

__call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

  • x: A JAX array of shape (seq_length, embedding_size).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (seq_length, embedding_size), with the rotary positional encoding applied to the input.