Embeddings
equinox.nn.Embedding(equinox.Module)
A simple lookup table that stores embeddings of a fixed size.
__init__(num_embeddings: int | None = None, embedding_size: int | None = None, weight: Float[Array, 'num_embeddings embedding_size'] | None = None, dtype=None, *, key: PRNGKeyArray | None = None)
Arguments:
Embedding should be initialised with either:

- `num_embeddings`: Size of embedding dictionary. Must be non-negative.
- `embedding_size`: Size of each embedding vector. Must be non-negative.
- `dtype`: The dtype to use for the embedding weights. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
- `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation of the embedding lookup table. (Keyword only argument.)

Or:

- `weight`: The embedding lookup table, of shape `(num_embeddings, embedding_size)`.
__call__(x: Int[ArrayLike, ''], *, key: PRNGKeyArray | None = None) -> Array
Arguments:
- `x`: The table index. Should be a scalar integer array.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (embedding_size,), from the x-th index of the embedding
table.
equinox.nn.RotaryPositionalEmbedding(equinox.Module)
A rotary positional encoding module, as described in the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding". While this module can be used in any context, it is particularly useful for providing positional information to transformer models.
Example
The following example demonstrates how to use RotaryPositionalEmbedding in
a simple transformer model.
```python
class TransformerBlock(eqx.Module):
    rope_embeddings: RotaryPositionalEmbedding

    def __init__(...):
        self.rope_embeddings = RotaryPositionalEmbedding(...)

    def __call__(...):
        def process_heads(
            query_heads: Float[Array, "seq_length num_heads qk_size"],
            key_heads: Float[Array, "seq_length num_heads qk_size"],
            value_heads: Float[Array, "seq_length num_heads vo_size"]
        ) -> tuple[
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads vo_size"]
        ]:
            # Apply the rotary encoding to each head of the queries and keys;
            # the values are passed through unchanged.
            query_heads = jax.vmap(self.rope_embeddings,
                                   in_axes=1, out_axes=1)(query_heads)
            key_heads = jax.vmap(self.rope_embeddings,
                                 in_axes=1, out_axes=1)(key_heads)
            return query_heads, key_heads, value_heads

        x = self.mha_attention(..., process_heads=process_heads)
        ...
```
Cite
RoFormer: Enhanced Transformer with Rotary Position Embedding
```bibtex
@misc{su2023roformer,
    title={RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author={Jianlin Su and Yu Lu and Shengfeng Pan and Ahmed Murtadha and
            Bo Wen and Yunfeng Liu},
    year={2023},
    eprint={arXiv:2104.09864},
}
```
__init__(embedding_size: int, theta: float = 10000.0, dtype: Any = <factory>)
Arguments:
- `embedding_size`: Size of each embedding vector. Must be non-negative and even.
- `theta`: The base frequency for the sinusoidal functions used in positional encoding. Specifies how quickly the inner product decays with the relative distance between tokens: larger values of `theta` result in slower oscillations. Default is 10_000, as per the original paper.
- `dtype`: The dtype to use for the precomputed frequencies. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode.
__call__(x: Float[Array, 'seq_length embedding_size'], *, key: PRNGKeyArray | None = None) -> Float[Array, 'seq_length embedding_size']
Arguments:
- `x`: A JAX array of shape `(seq_length, embedding_size)`.
- `key`: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (seq_length, embedding_size), with the rotary positional
encoding applied to the input.