Train RNN
Here we give a complete example of what using Equinox normally looks like day-to-day: we'll train an RNN to classify clockwise versus anticlockwise spirals.
import math
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import optax # https://github.com/deepmind/optax
import equinox as eqx
We begin by importing the usual libraries, setting up a very simple dataloader, and generating a toy dataset of spirals.
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        # Reshuffle with a fresh key on every pass through the dataset.
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end <= dataset_size:  # <= so an exact final batch isn't dropped
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
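As a quick illustration of how the dataloader behaves (a standalone sketch with hypothetical toy arrays, not part of the example itself):

xs = jnp.arange(10.0)
ys = 2 * jnp.arange(10.0)
loader = dataloader((xs, ys), batch_size=4, key=jrandom.PRNGKey(0))
batch_x, batch_y = next(loader)  # each draw yields one shuffled batch
print(batch_x.shape, batch_y.shape)  # (4,) (4,)
# Any final partial batch is dropped, and the data reshuffles on the next pass.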
def get_data(dataset_size, *, key):
    t = jnp.linspace(0, 2 * math.pi, 16)
    offset = jrandom.uniform(key, (dataset_size, 1), minval=0, maxval=2 * math.pi)
    x1 = jnp.sin(t + offset) / (1 + t)
    x2 = jnp.cos(t + offset) / (1 + t)
    y = jnp.ones((dataset_size, 1))
    half_dataset_size = dataset_size // 2
    # Reflect the first half of the spirals to reverse their direction, and label them 0.
    x1 = x1.at[:half_dataset_size].multiply(-1)
    y = y.at[:half_dataset_size].set(0)
    x = jnp.stack([x1, x2], axis=-1)
    return x, y
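If it helps to see concrete shapes: each spiral is a length-16 sequence of 2D points, with the first half of the dataset relabelled 0. A quick sanity check (hypothetical, not part of the example):

x, y = get_data(8, key=jrandom.PRNGKey(0))
print(x.shape)  # (8, 16, 2): 8 spirals, 16 timesteps, 2 coordinates each
print(y[:, 0])  # [0. 0. 0. 0. 1. 1. 1. 1.]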
Now for our model.

Purely by way of example, we add the final bias ourselves, rather than letting the linear layer do it. This is just so we can demonstrate how to use custom parameters in models.
class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jnp.ndarray

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        # Run the GRU cell over the sequence; `out` is the final hidden state.
        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)
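As a sketch of how the model is called (shapes matching get_data above; this snippet isn't part of the example): a single sequence has shape (timesteps, features), and jax.vmap maps the model over a batch.

model = RNN(in_size=2, out_size=1, hidden_size=16, key=jrandom.PRNGKey(0))
seq = jnp.zeros((16, 2))  # one sequence: 16 timesteps, 2 features
print(model(seq).shape)  # (1,): a probability for the binary classification
batch = jnp.zeros((32, 16, 2))  # a batch of 32 sequences
print(jax.vmap(model)(batch).shape)  # (32, 1)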
And finally the training loop.
def main(
    dataset_size=10000,
    batch_size=32,
    learning_rate=3e-3,
    steps=200,
    hidden_size=16,
    seed=5678,
):
    data_key, loader_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 3)
    xs, ys = get_data(dataset_size, key=data_key)
    iter_data = dataloader((xs, ys), batch_size, key=loader_key)

    model = RNN(in_size=2, out_size=1, hidden_size=hidden_size, key=model_key)

    @eqx.filter_value_and_grad
    def compute_loss(model, x, y):
        pred_y = jax.vmap(model)(x)
        # Trains with respect to binary cross-entropy
        return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

    # Important for efficiency whenever you use JAX: wrap everything into a single JIT
    # region.
    @eqx.filter_jit
    def make_step(model, x, y, opt_state):
        loss, grads = compute_loss(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    optim = optax.adam(learning_rate)
    # Initialise the optimiser state on just the trainable (floating-point) arrays.
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    for step, (x, y) in zip(range(steps), iter_data):
        loss, model, opt_state = make_step(model, x, y, opt_state)
        loss = loss.item()
        print(f"step={step}, loss={loss}")

    pred_ys = jax.vmap(model)(xs)
    num_correct = jnp.sum((pred_ys > 0.5) == ys)
    final_accuracy = (num_correct / dataset_size).item()
    print(f"final_accuracy={final_accuracy}")
eqx.filter_value_and_grad will calculate the gradient with respect to its first argument (model). By default it will calculate gradients for all the floating-point JAX arrays and ignore everything else. For example, the model parameters will be differentiated, whilst model.hidden_size is an integer and will be left alone. If you need finer control then these defaults can be adjusted; see equinox.filter_grad and equinox.filter_value_and_grad.
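To make this filtering concrete, here's a tiny standalone sketch (a hypothetical toy module, not part of the example above):

class Toy(eqx.Module):
    weight: jnp.ndarray
    size: int

@eqx.filter_grad
def loss_fn(toy, x):
    return jnp.sum(toy.weight * x)

toy = Toy(weight=jnp.array([1.0, 2.0]), size=3)
grads = loss_fn(toy, jnp.array([4.0, 5.0]))
print(grads.weight)  # [4. 5.]: the floating-point array is differentiated
print(grads.size)    # None: the integer is left alone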
Likewise, by default, eqx.filter_jit will look at all the arguments passed to make_step, and automatically JIT-trace every array and hold everything else static. For example, the model parameters and the data x and y will be traced, whilst model.hidden_size is an integer and will be treated as static instead. Once again, if you need finer control then these defaults can be adjusted; see equinox.filter_jit.
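And similarly for the JIT filtering. In this sketch (again hypothetical, not from the example), the array argument is traced whilst the integer is treated as static, so changing the integer triggers a recompilation:

@eqx.filter_jit
def scale(x, n):
    print("compiling!")  # Python side effects only run when JAX (re)traces
    return x * n

scale(jnp.arange(3.0), 2)  # prints "compiling!": x is traced, n=2 is static
scale(jnp.ones(3), 2)      # same shape and static value: no recompilation
scale(jnp.ones(3), 3)      # new static value: prints "compiling!" again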
main() # All right, let's run the code.