Recurrent

equinox.nn.GRUCell (Module)

A single step of a Gated Recurrent Unit (GRU).

Example

This cell is often used inside a jax.lax.scan to process a whole sequence. For example:

import jax
import jax.numpy as jnp
from equinox import Module
from equinox.nn import GRUCell

class Model(Module):
    cell: GRUCell

    def __init__(self, **kwargs):
        self.cell = GRUCell(**kwargs)

    def __call__(self, xs):
        # xs has shape (sequence_length, input_size); only the hidden state is carried.
        scan_fn = lambda state, input: (self.cell(input, state), None)
        init_state = jnp.zeros(self.cell.hidden_size)
        final_state, _ = jax.lax.scan(scan_fn, init_state, xs)
        return final_state

__init__(self, input_size: int, hidden_size: int, use_bias: bool = True, dtype = None, *, key: PRNGKeyArray)

Arguments:

  • input_size: The dimensionality of the input vector at each time step.
  • hidden_size: The dimensionality of the hidden state passed along between time steps.
  • use_bias: Whether to add on a bias after each update.
  • dtype: The dtype to use for all weights and biases in this GRU cell. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, input: Array, hidden: Array, *, key: Optional[PRNGKeyArray] = None)

Arguments:

  • input: The input, which should be a JAX array of shape (input_size,).
  • hidden: The hidden state, which should be a JAX array of shape (hidden_size,).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

The updated hidden state, which is a JAX array of shape (hidden_size,).


equinox.nn.LSTMCell (Module)

A single step of a Long Short-Term Memory unit (LSTM).

Example

This cell is often used inside a jax.lax.scan to process a whole sequence. For example:

import jax
import jax.numpy as jnp
from equinox import Module
from equinox.nn import LSTMCell

class Model(Module):
    cell: LSTMCell

    def __init__(self, **kwargs):
        self.cell = LSTMCell(**kwargs)

    def __call__(self, xs):
        # The carried state is a 2-tuple of arrays, each of shape (hidden_size,).
        scan_fn = lambda state, input: (self.cell(input, state), None)
        init_state = (jnp.zeros(self.cell.hidden_size),
                      jnp.zeros(self.cell.hidden_size))
        final_state, _ = jax.lax.scan(scan_fn, init_state, xs)
        return final_state

__init__(self, input_size: int, hidden_size: int, use_bias: bool = True, dtype = None, *, key: PRNGKeyArray)

Arguments:

  • input_size: The dimensionality of the input vector at each time step.
  • hidden_size: The dimensionality of the hidden state passed along between time steps.
  • use_bias: Whether to add on a bias after each update.
  • dtype: The dtype to use for all weights and biases in this LSTM cell. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__(self, input, hidden, *, key = None)

Arguments:

  • input: The input, which should be a JAX array of shape (input_size,).
  • hidden: The hidden state, which should be a 2-tuple of JAX arrays, each of shape (hidden_size,).
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

The updated hidden state, which is a 2-tuple of JAX arrays, each of shape (hidden_size,).