
Train RNN

This is an introductory example. We demonstrate what using Equinox normally looks like day-to-day.

Here, we'll train an RNN to classify clockwise vs anticlockwise spirals.

This example is available as a Jupyter notebook here.

import math

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax  # https://github.com/deepmind/optax

We begin by importing the usual libraries, setting up a very simple dataloader, and generating a toy dataset of spirals.

def dataloader(arrays, batch_size):
    # Loop forever over shuffled minibatches; any incomplete final batch is dropped.
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = np.arange(dataset_size)
    while True:
        perm = np.random.permutation(indices)
        start = 0
        end = batch_size
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


def get_data(dataset_size, *, key):
    t = jnp.linspace(0, 2 * math.pi, 16)
    offset = jrandom.uniform(key, (dataset_size, 1), minval=0, maxval=2 * math.pi)
    x1 = jnp.sin(t + offset) / (1 + t)
    x2 = jnp.cos(t + offset) / (1 + t)
    y = jnp.ones((dataset_size, 1))

    # Flip the direction of the first half of the spirals, and relabel them 0.
    half_dataset_size = dataset_size // 2
    x1 = x1.at[:half_dataset_size].multiply(-1)
    y = y.at[:half_dataset_size].set(0)
    x = jnp.stack([x1, x2], axis=-1)

    return x, y
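
If you like, you can sanity-check the data pipeline by drawing a single batch. (This isn't needed for training; the variable names below are just illustrative.)

check_x, check_y = get_data(256, key=jrandom.PRNGKey(0))
batch_x, batch_y = next(dataloader((check_x, check_y), batch_size=32))
print(batch_x.shape)  # (32, 16, 2): 32 spirals, each a length-16 sequence of 2D points
print(batch_y.shape)  # (32, 1): one binary label per spiral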

Now for our model.

Purely by way of example, we add the final bias ourselves, rather than letting the linear layer do it. This is just so we can demonstrate how to use custom parameters in models.

class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        # Scan the GRU cell over the timesteps; `out` is the final hidden state.
        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)
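
The model acts on a single unbatched sequence of shape (timesteps, features); batching is added later in the training loop with jax.vmap. As a quick optional check, a freshly initialised model applied to one dummy spiral returns a single probability:

demo_model = RNN(in_size=2, out_size=1, hidden_size=16, key=jrandom.PRNGKey(0))
demo_input = jnp.zeros((16, 2))  # one dummy spiral: 16 timesteps of (x1, x2)
print(demo_model(demo_input).shape)  # (1,) -- a probability after the sigmoid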

And finally the training loop.

def main(
    dataset_size=10000,
    batch_size=32,
    learning_rate=3e-3,
    steps=200,
    hidden_size=16,
    seed=5678,
):
    data_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 2)
    xs, ys = get_data(dataset_size, key=data_key)
    iter_data = dataloader((xs, ys), batch_size)

    model = RNN(in_size=2, out_size=1, hidden_size=hidden_size, key=model_key)

    @eqx.filter_value_and_grad
    def compute_loss(model, x, y):
        pred_y = jax.vmap(model)(x)
        # Trains with respect to binary cross-entropy
        return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

    # Important for efficiency whenever you use JAX: wrap everything into a single JIT
    # region.
    @eqx.filter_jit
    def make_step(model, x, y, opt_state):
        loss, grads = compute_loss(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    optim = optax.adam(learning_rate)
    opt_state = optim.init(model)
    for step, (x, y) in zip(range(steps), iter_data):
        loss, model, opt_state = make_step(model, x, y, opt_state)
        loss = loss.item()
        print(f"step={step}, loss={loss}")

    pred_ys = jax.vmap(model)(xs)
    num_correct = jnp.sum((pred_ys > 0.5) == ys)
    final_accuracy = (num_correct / dataset_size).item()
    print(f"final_accuracy={final_accuracy}")

eqx.filter_value_and_grad will calculate the gradient with respect to all floating-point arrays in the first argument (model). In this case the model parameters will be differentiated, whilst model.hidden_size is an integer and will get None as its gradient.
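
As a minimal sketch of this filtering behaviour (illustrative only, not part of the training loop): the returned gradients form a pytree with the same structure as the model, with None in place of the non-differentiated leaves.

grad_demo_model = RNN(in_size=2, out_size=1, hidden_size=16, key=jrandom.PRNGKey(0))

@eqx.filter_value_and_grad
def demo_loss(model, x):
    return jnp.sum(jax.vmap(model)(x))

_, grads = demo_loss(grad_demo_model, jnp.zeros((4, 16, 2)))
print(grads.hidden_size)  # None -- the integer field is not differentiated
print(grads.bias.shape)   # (1,) -- same shape as model.bias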

Likewise, eqx.filter_jit will look at all the arguments passed to make_step, and automatically trace every array whilst treating everything else as static. In this case the model parameters and the data x and y will be traced, whilst model.hidden_size is an integer and will be treated as static instead.
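
The array/non-array split that eqx.filter_jit relies on can be inspected directly with eqx.partition. (A rough illustration, reusing the model from the previous snippet.)

params, static = eqx.partition(grad_demo_model, eqx.is_array)
# `params` holds the arrays (these are traced); `static` holds everything else,
# e.g. the integer hidden_size, which is treated as a compile-time constant.
print(params.hidden_size)  # None
print(static.hidden_size)  # 16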

main()  # All right, let's run the code.
step=0, loss=0.6816999316215515
step=1, loss=0.7202574014663696
step=2, loss=0.6925007104873657
step=3, loss=0.689198911190033
step=4, loss=0.6808685064315796
step=5, loss=0.7059305906295776
step=6, loss=0.6922754049301147
step=7, loss=0.6842439770698547
step=8, loss=0.6972116231918335
step=9, loss=0.7047306299209595
step=10, loss=0.6993851661682129
step=11, loss=0.6921849846839905
step=12, loss=0.6844913959503174
step=13, loss=0.6941200494766235
step=14, loss=0.6870629787445068
step=15, loss=0.6922240257263184
step=16, loss=0.6966875195503235
step=17, loss=0.7021255493164062
step=18, loss=0.6913468241691589
step=19, loss=0.6915531158447266
step=20, loss=0.6906869411468506
step=21, loss=0.6945821046829224
step=22, loss=0.6963403820991516
step=23, loss=0.6893304586410522
step=24, loss=0.6923031210899353
step=25, loss=0.6952496767044067
step=26, loss=0.6937462687492371
step=27, loss=0.6946915984153748
step=28, loss=0.6912715435028076
step=29, loss=0.6945470571517944
step=30, loss=0.6928573250770569
step=31, loss=0.6918295621871948
step=32, loss=0.6926039457321167
step=33, loss=0.691811203956604
step=34, loss=0.696336567401886
step=35, loss=0.693527340888977
step=36, loss=0.6909832954406738
step=37, loss=0.6898350715637207
step=38, loss=0.693118691444397
step=39, loss=0.6962690353393555
step=40, loss=0.6943768262863159
step=41, loss=0.6929119229316711
step=42, loss=0.6921533942222595
step=43, loss=0.6970506906509399
step=44, loss=0.6914128065109253
step=45, loss=0.6925110220909119
step=46, loss=0.6876767873764038
step=47, loss=0.6977562308311462
step=48, loss=0.6887734532356262
step=49, loss=0.6956733465194702
step=50, loss=0.6988524198532104
step=51, loss=0.6972949504852295
step=52, loss=0.6935367584228516
step=53, loss=0.6899304389953613
step=54, loss=0.6940433979034424
step=55, loss=0.6932569742202759
step=56, loss=0.6964170932769775
step=57, loss=0.6952816843986511
step=58, loss=0.6925933361053467
step=59, loss=0.700016975402832
step=60, loss=0.6929588317871094
step=61, loss=0.6919406652450562
step=62, loss=0.6893216371536255
step=63, loss=0.6881398558616638
step=64, loss=0.6941375136375427
step=65, loss=0.6908596754074097
step=66, loss=0.6938614845275879
step=67, loss=0.6939255595207214
step=68, loss=0.691447377204895
step=69, loss=0.6932423114776611
step=70, loss=0.6937750577926636
step=71, loss=0.691257119178772
step=72, loss=0.6900532245635986
step=73, loss=0.6922309398651123
step=74, loss=0.6899502277374268
step=75, loss=0.6930654048919678
step=76, loss=0.6942011117935181
step=77, loss=0.6899413466453552
step=78, loss=0.6950610876083374
step=79, loss=0.6900242567062378
step=80, loss=0.691747784614563
step=81, loss=0.6899303793907166
step=82, loss=0.6910462379455566
step=83, loss=0.69475257396698
step=84, loss=0.6886341571807861
step=85, loss=0.6912660598754883
step=86, loss=0.6889529824256897
step=87, loss=0.6940121054649353
step=88, loss=0.6970347762107849
step=89, loss=0.687224268913269
step=90, loss=0.6900577545166016
step=91, loss=0.6913183927536011
step=92, loss=0.6916753649711609
step=93, loss=0.6899659633636475
step=94, loss=0.6911211013793945
step=95, loss=0.694290041923523
step=96, loss=0.7031664848327637
step=97, loss=0.6912339925765991
step=98, loss=0.6968348026275635
step=99, loss=0.6970176100730896
step=100, loss=0.6857004165649414
step=101, loss=0.6842451095581055
step=102, loss=0.6882964968681335
step=103, loss=0.6855384111404419
step=104, loss=0.6909692287445068
step=105, loss=0.6905874013900757
step=106, loss=0.6900045871734619
step=107, loss=0.6865564584732056
step=108, loss=0.6820229887962341
step=109, loss=0.6879786849021912
step=110, loss=0.6853011846542358
step=111, loss=0.68475741147995
step=112, loss=0.682267427444458
step=113, loss=0.6880433559417725
step=114, loss=0.6814002990722656
step=115, loss=0.6823583841323853
step=116, loss=0.6794727444648743
step=117, loss=0.6785068511962891
step=118, loss=0.6811013221740723
step=119, loss=0.6747442483901978
step=120, loss=0.6660218238830566
step=121, loss=0.6700407266616821
step=122, loss=0.6526561975479126
step=123, loss=0.6608943939208984
step=124, loss=0.6293025612831116
step=125, loss=0.6483496427536011
step=126, loss=0.6219364404678345
step=127, loss=0.5961954593658447
step=128, loss=0.6002600193023682
step=129, loss=0.5647848844528198
step=130, loss=0.5256890058517456
step=131, loss=0.510317325592041
step=132, loss=0.47984960675239563
step=133, loss=0.5084915161132812
step=134, loss=0.4301827549934387
step=135, loss=0.4290550649166107
step=136, loss=0.3755859136581421
step=137, loss=0.2937808036804199
step=138, loss=0.26023393869400024
step=139, loss=0.23048073053359985
step=140, loss=0.21439003944396973
step=141, loss=0.1652923822402954
step=142, loss=0.1283920854330063
step=143, loss=0.10732141137123108
step=144, loss=0.09533026814460754
step=145, loss=0.0801059827208519
step=146, loss=0.06879423558712006
step=147, loss=0.05884774401783943
step=148, loss=0.04997169226408005
step=149, loss=0.04442397877573967
step=150, loss=0.03894098848104477
step=151, loss=0.03257341682910919
step=152, loss=0.029269497841596603
step=153, loss=0.025602124631404877
step=154, loss=0.022208761423826218
step=155, loss=0.019779304042458534
step=156, loss=0.0190683975815773
step=157, loss=0.01669464260339737
step=158, loss=0.015229889191687107
step=159, loss=0.013572589494287968
step=160, loss=0.013138247653841972
step=161, loss=0.011538836173713207
step=162, loss=0.011130412109196186
step=163, loss=0.010209350846707821
step=164, loss=0.00980051327496767
step=165, loss=0.009073751047253609
step=166, loss=0.008881419897079468
step=167, loss=0.008286576718091965
step=168, loss=0.007832735776901245
step=169, loss=0.007236967794597149
step=170, loss=0.006913297809660435
step=171, loss=0.006962948478758335
step=172, loss=0.0067381164990365505
step=173, loss=0.0065834359265863895
step=174, loss=0.006057928316295147
step=175, loss=0.006013630423694849
step=176, loss=0.005771873984485865
step=177, loss=0.005476493388414383
step=178, loss=0.005490635521709919
step=179, loss=0.0053573474287986755
step=180, loss=0.005200999788939953
step=181, loss=0.005007486790418625
step=182, loss=0.004953362978994846
step=183, loss=0.004767540842294693
step=184, loss=0.004584995098412037
step=185, loss=0.004679872654378414
step=186, loss=0.004541940055787563
step=187, loss=0.004449024796485901
step=188, loss=0.004398198798298836
step=189, loss=0.004358283709734678
step=190, loss=0.004335914738476276
step=191, loss=0.004185144789516926
step=192, loss=0.00412558950483799
step=193, loss=0.0040787700563669205
step=194, loss=0.003923933021724224
step=195, loss=0.0039162710309028625
step=196, loss=0.0038443428929895163
step=197, loss=0.0037653278559446335
step=198, loss=0.0037484394852072
step=199, loss=0.0037215303163975477
final_accuracy=1.0