
Attention

equinox.nn.MultiheadAttention (Module)

Computes

\[\text{MultiheadAttention}(Q, K, V) = \sum_i \text{Attention}\left(QW^Q_i, KW^K_i, VW^V_i\right)W^O_i\]

where:

  • The inputs are \(Q \in \mathbb{R}^{d_\text{seq} \times d_\text{query}}\), \(K \in \mathbb{R}^{d_\text{seq} \times d_\text{key}}\), \(V \in \mathbb{R}^{d_\text{seq} \times d_\text{value}}\). These are referred to as query, key, and value respectively. Meanwhile \(d_\text{seq}\) is the sequence length, and \(d_\text{query}\), \(d_\text{key}\), \(d_\text{value}\) are numbers of channels.

  • The trainable weights are \(W^Q_i \in \mathbb{R}^{d_\text{query} \times d_\text{qk}}\), \(W^K_i \in \mathbb{R}^{d_\text{key} \times d_\text{qk}}\), \(W^V_i \in \mathbb{R}^{d_\text{value} \times d_\text{vo}}\), \(W^O_i \in \mathbb{R}^{d_\text{vo} \times d_\text{output}}\), with \(i \in \{1, \ldots, h\}\), where \(h\) is the number of heads, and \(d_\text{qk}\), \(d_\text{vo}\), \(d_\text{output}\) are hyperparameters.

  • \(\text{Attention}\) is defined as \(\text{Attention}(\widetilde{Q}, \widetilde{K}, \widetilde{V}) = \text{softmax}(\frac{\widetilde{Q}\widetilde{K}^\intercal} {\sqrt{d_\text{qk}}})\widetilde{V}\).
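For concreteness, the \(\text{Attention}\) operation above corresponds to the following JAX sketch (illustrative only; the function and argument names are ours, not part of the Equinox API):

import jax.numpy as jnp
from jax.nn import softmax

def attention(q_tilde, k_tilde, v_tilde):
    # Computes softmax(Q~ K~^T / sqrt(d_qk)) V~ for a single head.
    # q_tilde, k_tilde: shape (seq_length, d_qk); v_tilde: shape (seq_length, d_vo).
    d_qk = q_tilde.shape[-1]
    logits = q_tilde @ k_tilde.T / jnp.sqrt(d_qk)  # (seq_length, seq_length)
    return softmax(logits, axis=-1) @ v_tilde      # (seq_length, d_vo)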

Cite

Attention is All You Need

@inproceedings{vaswani2017attention,
    author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and
            Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and
            Kaiser, {\L}ukasz and Polosukhin, Illia},
    booktitle={Advances in Neural Information Processing Systems},
    publisher={Curran Associates, Inc.},
    title={Attention is All You Need},
    volume={30},
    year={2017}
}

FAQ

Different software libraries implement multihead attention in slightly different ways. Some add bias terms by default and some do not. Most fix the values of \(d_\text{qk}\), \(d_\text{vo}\), \(d_\text{output}\) in terms of \(d_\text{query}\), \(d_\text{key}\), or \(d_\text{value}\). Equinox chooses to expose all of these as options.

Relative to the original Attention is All You Need paper: our \(d_\text{qk}\) is their "\(d_k\)"; our \(d_\text{vo}\) is their "\(d_v\)". They fix \(d_\text{query} = d_\text{key} = d_\text{value} = d_\text{output}\) and refer to it as "\(d_\text{model}\)".

__init__(self, num_heads: int, query_size: int, key_size: Optional[int] = None, value_size: Optional[int] = None, output_size: Optional[int] = None, qk_size: Optional[int] = None, vo_size: Optional[int] = None, use_query_bias: bool = False, use_key_bias: bool = False, use_value_bias: bool = False, use_output_bias: bool = False, dropout_p: float = 0.0, inference: bool = False, *, key: jax.random.PRNGKey, **kwargs)

Arguments:

  • num_heads: Number of parallel attention heads \(h\).
  • query_size: Number of input channels for query \(Q\).
  • key_size: Number of input channels for key \(K\). Defaults to query_size.
  • value_size: Number of input channels for value \(V\). Defaults to query_size.
  • output_size: Number of output channels. Defaults to query_size.
  • qk_size: Number of channels to compare query and key over, per head. Defaults to query_size // num_heads.
  • vo_size: Number of channels to compare attention-weighted value and output over, per head. Defaults to query_size // num_heads.
  • use_query_bias: Whether to use a bias term in the query projections.
  • use_key_bias: Whether to use a bias term in the key projections.
  • use_value_bias: Whether to use a bias term in the value projections.
  • use_output_bias: Whether to use a bias term in the output projection.
  • dropout_p: Dropout probability on attention weights.
  • inference: Whether to run in inference mode and skip dropout: if True then dropout is not applied; if False then it is. This may be toggled with equinox.tree_inference or overridden during equinox.nn.MultiheadAttention.__call__.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
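For example (a usage sketch; the sizes are arbitrary, and every size not passed falls back to the defaults documented above):

import jax
import equinox as eqx

attn = eqx.nn.MultiheadAttention(
    num_heads=4,
    query_size=64,  # key_size, value_size, output_size also default to 64
    key=jax.random.PRNGKey(0),
)
# qk_size and vo_size each default to query_size // num_heads, i.e. 16 per head.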
__call__(self, query: Array, key_: Array, value: Array, mask: Optional[Array] = None, *, key: Optional[jax.random.PRNGKey] = None, inference: Optional[bool] = None, deterministic: Optional[bool] = None) -> Array

Arguments:

  • query: Query embedding. Should be a JAX array of shape (query_seq_length, query_size).
  • key_: Key embedding. Should be a JAX array of shape (kv_seq_length, key_size).
  • value: Value embedding. Should be a JAX array of shape (kv_seq_length, value_size).
  • mask: Optional mask preventing attention to certain positions. Should be a JAX array of shape (num_heads, query_seq_length, kv_seq_length).
  • key: A jax.random.PRNGKey used for dropout. Unused if dropout_p = 0. (Keyword only argument.)
  • inference: As equinox.nn.Dropout.__call__. (Keyword only argument.)
  • deterministic: (Deprecated in favour of inference.)

Returns:

A JAX array of shape (query_seq_length, output_size).
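Putting the pieces together, a self-attention call might look like the following sketch (shapes match those documented above; the causal mask is just one example of a valid mask):

import jax
import jax.numpy as jnp
import equinox as eqx

attn = eqx.nn.MultiheadAttention(num_heads=4, query_size=64, key=jax.random.PRNGKey(0))

seq_length = 10
x = jax.random.normal(jax.random.PRNGKey(1), (seq_length, 64))

# Causal mask of shape (num_heads, query_seq_length, kv_seq_length).
causal = jnp.tril(jnp.ones((seq_length, seq_length), dtype=bool))
mask = jnp.broadcast_to(causal, (4, seq_length, seq_length))

out = attn(query=x, key_=x, value=x, mask=mask)
assert out.shape == (seq_length, 64)  # (query_seq_length, output_size)

Note that __call__ operates on a single (unbatched) sequence; as usual in Equinox, use jax.vmap to map over a batch.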