Embeddings

equinox.nn.Embedding (Module)

A simple lookup table that stores embeddings of a fixed size.

__init__(self, num_embeddings: int, embedding_size: int, weight: Optional[Array] = None, *, key: jax.random.PRNGKey, **kwargs)

Arguments:

  • num_embeddings: Number of entries in the embedding table (the dictionary size).
  • embedding_size: Size of each embedding vector.
  • weight: The embedding lookup table to use, if given. If not provided, the table is generated randomly.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

__call__(self, x: Array, *, key: Optional[jax.random.PRNGKey] = None) -> Array

Arguments:

  • x: The index into the embedding table.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (embedding_size,): the entry at index x of the embedding table.