Recurrent
equinox.nn.GRUCell (Module)
A single step of a Gated Recurrent Unit (GRU).
Example
This is often used by wrapping it in a jax.lax.scan. For example:
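import jax
import jax.numpy as jnp
from equinox import Module
from equinox.nn import GRUCell

class Model(Module):
    cell: GRUCell

    def __init__(self, **kwargs):
        self.cell = GRUCell(**kwargs)

    def __call__(self, xs):
        # xs has shape (sequence_length, input_size); scan over the leading axis.
        scan_fn = lambda state, input: (self.cell(input, state), None)
        init_state = jnp.zeros(self.cell.hidden_size)
        final_state, _ = jax.lax.scan(scan_fn, init_state, xs)
        return final_state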
__init__(self, input_size: int, hidden_size: int, use_bias: bool = True, dtype = None, *, key: PRNGKeyArray)
Arguments:
input_size: The dimensionality of the input vector at each time step.
hidden_size: The dimensionality of the hidden state passed along between time steps.
use_bias: Whether to add on a bias after each update.
dtype: The dtype to use for all weights and biases in this GRU cell. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
__call__(self, input: Array, hidden: Array, *, key: Optional[PRNGKeyArray] = None)
Arguments:
input: The input, which should be a JAX array of shape (input_size,).
hidden: The hidden state, which should be a JAX array of shape (hidden_size,).
key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword-only argument.)
Returns:
The updated hidden state, which is a JAX array of shape (hidden_size,).
equinox.nn.LSTMCell (Module)
A single step of a Long Short-Term Memory unit (LSTM).
Example
This is often used by wrapping it in a jax.lax.scan. For example:
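import jax
import jax.numpy as jnp
from equinox import Module
from equinox.nn import LSTMCell

class Model(Module):
    cell: LSTMCell

    def __init__(self, **kwargs):
        self.cell = LSTMCell(**kwargs)

    def __call__(self, xs):
        scan_fn = lambda state, input: (self.cell(input, state), None)
        # The LSTM state is an (h, c) pair, both starting at zero.
        init_state = (jnp.zeros(self.cell.hidden_size),
                      jnp.zeros(self.cell.hidden_size))
        (h, c), _ = jax.lax.scan(scan_fn, init_state, xs)
        return h, c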
__init__(self, input_size: int, hidden_size: int, use_bias: bool = True, dtype = None, *, key: PRNGKeyArray)
Arguments:
input_size: The dimensionality of the input vector at each time step.
hidden_size: The dimensionality of the hidden state passed along between time steps.
use_bias: Whether to add on a bias after each update.
dtype: The dtype to use for all weights and biases in this LSTM cell. Defaults to either jax.numpy.float32 or jax.numpy.float64, depending on whether JAX is in 64-bit mode.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
__call__(self, input, hidden, *, key = None)
Arguments:
input: The input, which should be a JAX array of shape (input_size,).
hidden: The hidden state, which should be a 2-tuple of JAX arrays, each of shape (hidden_size,).
key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword-only argument.)
Returns:
The updated hidden state, which is a 2-tuple of JAX arrays, each of shape (hidden_size,).