Attention
equinox.nn.MultiheadAttention (Module)
Computes

\[\text{MultiheadAttention}(Q, K, V) = \sum_i \text{Attention}\left(QW^Q_i, KW^K_i, VW^V_i\right)W^O_i\]

where:
- The inputs are \(Q \in \mathbb{R}^{d_\text{seq} \times d_\text{query}}\), \(K \in \mathbb{R}^{d_\text{seq} \times d_\text{key}}\), \(V \in \mathbb{R}^{d_\text{seq} \times d_\text{value}}\). These are referred to as query, key, and value respectively. Meanwhile \(d_\text{seq}\) is the sequence length, and \(d_\text{query}\), \(d_\text{key}\), \(d_\text{value}\) are numbers of channels.
- The trainable weights are \(W^Q_i \in \mathbb{R}^{d_\text{query} \times d_\text{qk}}\), \(W^K_i \in \mathbb{R}^{d_\text{key} \times d_\text{qk}}\), \(W^V_i \in \mathbb{R}^{d_\text{value} \times d_\text{vo}}\), \(W^O_i \in \mathbb{R}^{d_\text{vo} \times d_\text{output}}\), with \(i \in \{1, \ldots, h\}\), where \(h\) is the number of heads, and \(d_\text{qk}\), \(d_\text{vo}\), \(d_\text{output}\) are hyperparameters.
- \(\text{Attention}\) is defined as \(\text{Attention}(\widetilde{Q}, \widetilde{K}, \widetilde{V}) = \text{softmax}\left(\frac{\widetilde{Q}\widetilde{K}^\intercal}{\sqrt{d_\text{qk}}}\right)\widetilde{V}\).
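
As a sanity check on the shapes, here is a minimal sketch of the computation above in plain JAX, using the sum-over-heads form of the formula. All sizes are illustrative, and none of the names here are part of the Equinox API.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; these names are not part of the API.
h, d_seq, d_query, d_key, d_value = 2, 5, 8, 8, 8
d_qk, d_vo, d_output = 4, 4, 8

keys = jax.random.split(jax.random.PRNGKey(0), 7)
Q = jax.random.normal(keys[0], (d_seq, d_query))
K = jax.random.normal(keys[1], (d_seq, d_key))
V = jax.random.normal(keys[2], (d_seq, d_value))
# One weight matrix per head, stacked along a leading axis.
W_Q = jax.random.normal(keys[3], (h, d_query, d_qk))
W_K = jax.random.normal(keys[4], (h, d_key, d_qk))
W_V = jax.random.normal(keys[5], (h, d_value, d_vo))
W_O = jax.random.normal(keys[6], (h, d_vo, d_output))

def attention(q, k, v):
    # softmax(q k^T / sqrt(d_qk)) v, with the softmax taken over the key axis.
    logits = q @ k.T / jnp.sqrt(d_qk)
    return jax.nn.softmax(logits, axis=-1) @ v

# Sum over heads of Attention(Q W^Q_i, K W^K_i, V W^V_i) W^O_i.
out = sum(attention(Q @ W_Q[i], K @ W_K[i], V @ W_V[i]) @ W_O[i] for i in range(h))
assert out.shape == (d_seq, d_output)
```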
Cite
@inproceedings{vaswani2017attention,
    author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and
            Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and
            Kaiser, {\L}ukasz and Polosukhin, Illia},
    booktitle={Advances in Neural Information Processing Systems},
    publisher={Curran Associates, Inc.},
    title={Attention is All You Need},
    volume={30},
    year={2017}
}
FAQ
Different software libraries often implement multihead attention in slightly different ways. Some add bias terms by default and others do not. Most fix the values of \(d_\text{qk}\), \(d_\text{vo}\), \(d_\text{output}\) in terms of \(d_\text{query}\), \(d_\text{key}\), or \(d_\text{value}\). Equinox chooses to expose all of these as options.
Relative to the original Attention is All You Need paper: our \(d_\text{qk}\) is their "\(d_k\)". Our \(d_\text{vo}\) is their "\(d_\text{v}\)". They fix \(d_\text{query} = d_\text{key} = d_\text{value} = d_\text{output}\) and refer to it as "\(d_\text{model}\)".
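
For instance, the setup in the original paper (\(d_\text{model} = 512\), \(h = 8\), \(d_k = d_v = 64\)) can be recovered by passing only `query_size`, since the other sizes default in terms of it. A sketch, using the constructor documented below; the specific numbers are just the paper's:

```python
import equinox as eqx
import jax

# d_model = 512 and h = 8, as in the paper. key_size, value_size and
# output_size default to query_size (= 512), while qk_size and vo_size
# default to query_size // num_heads (= 64), matching the paper's d_k = d_v.
attn = eqx.nn.MultiheadAttention(
    num_heads=8, query_size=512, key=jax.random.PRNGKey(0)
)
```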
__init__(self, num_heads: int, query_size: int, key_size: Optional[int] = None, value_size: Optional[int] = None, output_size: Optional[int] = None, qk_size: Optional[int] = None, vo_size: Optional[int] = None, use_query_bias: bool = False, use_key_bias: bool = False, use_value_bias: bool = False, use_output_bias: bool = False, dropout_p: float = 0.0, inference: bool = False, *, key: jax.random.PRNGKey, **kwargs)
Arguments:
- `num_heads`: Number of parallel attention heads \(h\).
- `query_size`: Number of input channels for query \(Q\).
- `key_size`: Number of input channels for key \(K\). Defaults to `query_size`.
- `value_size`: Number of input channels for value \(V\). Defaults to `query_size`.
- `output_size`: Number of output channels. Defaults to `query_size`.
- `qk_size`: Number of channels to compare query and key over, per head. Defaults to `query_size // num_heads`.
- `vo_size`: Number of channels to compare attention-weighted value and output over, per head. Defaults to `query_size // num_heads`.
- `use_query_bias`: Whether to use a bias term in the query projections.
- `use_key_bias`: Whether to use a bias term in the key projections.
- `use_value_bias`: Whether to use a bias term in the value projections.
- `use_output_bias`: Whether to use a bias term in the output projection.
- `dropout_p`: Dropout probability on attention weights.
- `inference`: Whether to actually apply dropout at all. If `True` then dropout is not applied. If `False` then dropout is applied. This may be toggled with `equinox.tree_inference` or overridden during `equinox.nn.MultiheadAttention.__call__`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.)
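
For example, one might construct a cross-attention layer in which the key and value features have a different dimensionality from the queries. This is just an illustrative configuration of the arguments above, not a prescribed one:

```python
import equinox as eqx
import jax

# Illustrative sizes: queries have 64 channels, keys/values have 32.
attn = eqx.nn.MultiheadAttention(
    num_heads=4,
    query_size=64,
    key_size=32,
    value_size=32,
    use_query_bias=True,
    use_output_bias=True,
    dropout_p=0.1,
    key=jax.random.PRNGKey(0),
)
```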
__call__(self, query: Array, key_: Array, value: Array, mask: Optional[Array] = None, *, key: Optional[jax.random.PRNGKey] = None, inference: Optional[bool] = None, deterministic: Optional[bool] = None) -> Array
Arguments:
- `query`: Query embedding. Should be a JAX array of shape `(query_seq_length, query_size)`.
- `key_`: Key embedding. Should be a JAX array of shape `(kv_seq_length, key_size)`.
- `value`: Value embedding. Should be a JAX array of shape `(kv_seq_length, value_size)`.
- `mask`: Optional mask preventing attention to certain positions. Should be a JAX array of shape `(num_heads, query_seq_length, kv_seq_length)`.
- `key`: A `jax.random.PRNGKey` used for dropout. Unused if `dropout_p = 0`. (Keyword only argument.)
- `inference`: As `equinox.nn.Dropout.__call__`. (Keyword only argument.)
- `deterministic`: (Deprecated in favour of `inference`.)
Returns:
A JAX array of shape `(query_seq_length, output_size)`.
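
For example, a causal self-attention call might look as follows. The mask construction assumes, as is conventional, that `True` marks positions which may be attended to; all sizes are illustrative.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

num_heads, seq_len, size = 4, 10, 64
attn = eqx.nn.MultiheadAttention(num_heads, size, key=jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (seq_len, size))

# Causal mask: position i may attend to positions j <= i. The same
# (seq_len, seq_len) pattern is broadcast over every head.
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
mask = jnp.broadcast_to(causal, (num_heads, seq_len, seq_len))

out = attn(x, x, x, mask=mask)  # self-attention: query = key_ = value
assert out.shape == (seq_len, size)  # (query_seq_length, output_size)
```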