equinox.nn.Embedding (Module)

A simple lookup table that stores embeddings of a fixed size.
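
Conceptually, an embedding lookup amounts to selecting one row of a 2-D weight matrix. The minimal sketch below uses plain `jax.numpy` rather than Equinox to illustrate the idea; the table and its values are arbitrary and purely illustrative.

```python
import jax.numpy as jnp

# A toy table: 4 embeddings (the "dictionary" size), each a vector of size 3.
weight = jnp.arange(12.0).reshape(4, 3)

# Looking up index 2 just selects row 2 of the table.
index = jnp.array(2)
vector = weight[index]
print(vector)  # [6. 7. 8.]
```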

__init__(self, num_embeddings: Optional[int] = None, embedding_size: Optional[int] = None, weight: Optional[Array] = None, dtype = None, *, key: Optional[PRNGKeyArray] = None)


Embedding should be initialised with either:

  • num_embeddings: Size of the embedding dictionary, i.e. the number of distinct embeddings. Must be non-negative.
  • embedding_size: Size of each embedding vector. Must be non-negative.
  • dtype: The dtype to use for the embedding weights. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for initialisation of the embedding lookup table. (Keyword only argument.)

Or:

  • weight: The embedding lookup table, of shape (num_embeddings, embedding_size).

__call__(self, x: ArrayLike, *, key: Optional[PRNGKeyArray] = None) -> Array

Arguments:

  • x: The table index. Should be a scalar integer array.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (embedding_size,), from the x-th index of the embedding table.