Embeddings

equinox.nn.Embedding(equinox.Module)

A simple lookup table that stores embeddings of a fixed size.

__init__(num_embeddings: int | None = None, embedding_size: int | None = None, weight: Float[Array, 'num_embeddings embedding_size'] | None = None, dtype=None, *, key: PRNGKeyArray | None = None)

Arguments:

Embedding should be initialised with either:

  • num_embeddings: Size of embedding dictionary. Must be non-negative.
  • embedding_size: Size of each embedding vector. Must be non-negative.
  • dtype: The dtype to use for the embedding weights. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for initialisation of the embedding lookup table. (Keyword only argument.)

Or:

  • weight: The embedding lookup table, of shape (num_embeddings, embedding_size).

__call__(x: Int[ArrayLike, ''], *, key: PRNGKeyArray | None = None) -> Array

Arguments:

  • x: The table index. Should be a scalar integer array.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (embedding_size,): the row of the embedding table at index x.