Embeddings
equinox.nn.Embedding(equinox.Module)
A simple lookup table that stores embeddings of a fixed size.
__init__(num_embeddings: int | None = None, embedding_size: int | None = None, weight: Float[Array, 'num_embeddings embedding_size'] | None = None, dtype=None, *, key: PRNGKeyArray | None = None)
Arguments:
Embedding should be initialised with either:
- num_embeddings: Size of the embedding dictionary. Must be non-negative.
- embedding_size: Size of each embedding vector. Must be non-negative.
- dtype: The dtype to use for the embedding weights. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
- key: A jax.random.PRNGKey used to provide randomness for initialisation of the embedding lookup table. (Keyword only argument.)
Or:
- weight: The embedding lookup table, of shape (num_embeddings, embedding_size).
__call__(x: Int[ArrayLike, ''], *, key: PRNGKeyArray | None = None) -> Array
Arguments:
- x: The table index. Should be a scalar integer array.
- key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
Returns:
A JAX array of shape (embedding_size,): the row of the embedding table at index x.