Embeddings¤

equinox.nn.Embedding(equinox.Module) ¤

A simple lookup table that stores embeddings of a fixed size.

__init__(num_embeddings: int | None = None, embedding_size: int | None = None, weight: Float[Array, 'num_embeddings embedding_size'] | None = None, dtype=None, *, key: PRNGKeyArray | None = None) ¤

Arguments:

Embedding should be initialised with either:

  • num_embeddings: Size of embedding dictionary. Must be non-negative.
  • embedding_size: Size of each embedding vector. Must be non-negative.
  • dtype: The dtype to use for the embedding weights. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for initialisation of the embedding lookup table. (Keyword only argument.)

Or:

  • weight: The embedding lookup table, of shape (num_embeddings, embedding_size).

__call__(x: Int[ArrayLike, ''], *, key: PRNGKeyArray | None = None) -> Array ¤

Arguments:

  • x: The table index. Should be a scalar integer array.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (embedding_size,), from the x-th index of the embedding table.


equinox.nn.RotaryPositionalEmbedding(equinox.Module) ¤

A rotary positional encoding module, as described in the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding". While this module can be used in any context, it is particularly useful for providing positional information to transformer models.

Example

The following example demonstrates how to use RotaryPositionalEmbedding in a simple transformer model.

class TransformerBlock(eqx.Module):
    rope_embeddings: RotaryPositionalEmbedding

    def __init__(...):
        self.rope_embeddings = RotaryPositionalEmbedding(...)

    def __call__(...):
        def process_heads(
            query_heads: Float[Array, "seq_length num_heads qk_size"],
            key_heads: Float[Array, "seq_length num_heads qk_size"],
            value_heads: Float[Array, "seq_length num_heads vo_size"]
        ) -> tuple[
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads qk_size"],
            Float[Array, "seq_length num_heads vo_size"]
        ]:
            query_heads = jax.vmap(self.rope_embeddings,
                                   in_axes=1,
                                   out_axes=1)(query_heads)
            key_heads = jax.vmap(self.rope_embeddings,
                                 in_axes=1,
                                 out_axes=1)(key_heads)

            return query_heads, key_heads, value_heads

        x = self.mha_attention(..., process_heads=process_heads)
        ...
Cite

RoFormer: Enhanced Transformer with Rotary Position Embedding

@misc{su2023roformer,
  title={RoFormer: Enhanced Transformer with Rotary Position Embedding},
  author={Jianlin Su and Yu Lu and Shengfeng Pan and Ahmed Murtadha and
      Bo Wen and Yunfeng Liu},
  year={2023},
  eprint={2104.09864},
  archivePrefix={arXiv},
}

__init__(embedding_size: int, theta: float = 10000.0, dtype: Any = <factory>) ¤

Arguments:

  • embedding_size: Size of each embedding vector. Must be non-negative and even.
  • theta: The base frequency for the sinusoidal functions used in positional encoding. Specifies how quickly the inner-product will decay with relative distance between tokens. Larger values of theta will result in slower oscillations. Default is 10_000, as per the original paper.
  • dtype: The dtype to use for the precomputed frequencies. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

__call__(x: Float[Array, 'seq_length embedding_size'], *, key: PRNGKeyArray | None = None) -> Float[Array, 'seq_length embedding_size'] ¤

Arguments:

  • x: A JAX array of shape (seq_length, embedding_size).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (seq_length, embedding_size), with the rotary positional encoding applied to the input.