Embeddings
equinox.nn.Embedding (Module)
A simple lookup table that stores embeddings of a fixed size.
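To make the "lookup table" idea concrete, here is a plain-JAX sketch (no Equinox involved) using a made-up 5-by-3 table. It shows that selecting row i of a (num_embeddings, embedding_size) matrix gives the same vector as multiplying a one-hot encoding of i by that matrix. It only illustrates the concept and is not meant to describe how Equinox implements Embedding internally.

```python
import jax
import jax.numpy as jnp

# A toy lookup table: 5 embeddings, each of size 3 (values chosen for illustration only).
num_embeddings, embedding_size = 5, 3
weight = jnp.arange(num_embeddings * embedding_size, dtype=jnp.float32).reshape(
    num_embeddings, embedding_size
)

index = jnp.array(2)

# Looking up an embedding is just selecting one row of the table...
row = weight[index]

# ...which matches multiplying a one-hot vector for the same index by the table.
one_hot = jax.nn.one_hot(index, num_embeddings, dtype=weight.dtype)
via_matmul = one_hot @ weight

print(row)  # [6. 7. 8.]
```

Direct indexing avoids building the one-hot vector and the matrix product, which is why a lookup table is the usual way to store embeddings.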
__init__(self, num_embeddings: Optional[int] = None, embedding_size: Optional[int] = None, weight: Optional[Array] = None, dtype = None, *, key: Optional[PRNGKeyArray] = None)
Arguments:

Embedding should be initialised with either:

- num_embeddings: Size of the embedding dictionary. Must be non-negative.
- embedding_size: Size of each embedding vector. Must be non-negative.
- dtype: The dtype to use for the embedding weights. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
- key: A jax.random.PRNGKey used to provide randomness for initialisation of the embedding lookup table. (Keyword-only argument.)

Or:

- weight: The embedding lookup table, of shape (num_embeddings, embedding_size).
__call__(self, x: ArrayLike, *, key: Optional[PRNGKeyArray] = None) -> Array
Arguments:

- x: The table index. Should be a scalar integer array.
- key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword-only argument.)
Returns:

A JAX array of shape (embedding_size,), taken from the x-th index of the embedding table.