
Train RNN

Here we give a complete example of what using Equinox normally looks like day-to-day.

In this example we'll train an RNN to classify clockwise vs anticlockwise spirals.

import math

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import optax  # https://github.com/deepmind/optax

import equinox as eqx

We begin by importing the usual libraries, setting up a very simple dataloader, and generating a toy dataset of spirals.

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


def get_data(dataset_size, *, key):
    t = jnp.linspace(0, 2 * math.pi, 16)
    offset = jrandom.uniform(key, (dataset_size, 1), minval=0, maxval=2 * math.pi)
    x1 = jnp.sin(t + offset) / (1 + t)
    x2 = jnp.cos(t + offset) / (1 + t)
    y = jnp.ones((dataset_size, 1))

    half_dataset_size = dataset_size // 2
    x1 = x1.at[:half_dataset_size].multiply(-1)
    y = y.at[:half_dataset_size].set(0)
    x = jnp.stack([x1, x2], axis=-1)

    return x, y
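
If it helps to see what these helpers produce, here is a quick shape check. This is only a sketch, not part of the example proper, and the names are throwaway:

toy_key = jrandom.PRNGKey(0)
toy_xs, toy_ys = get_data(256, key=toy_key)
print(toy_xs.shape, toy_ys.shape)  # (256, 16, 2) (256, 1): 16 time steps, 2 features, one 0/1 label
batch_x, batch_y = next(dataloader((toy_xs, toy_ys), 32, key=toy_key))
print(batch_x.shape, batch_y.shape)  # (32, 16, 2) (32, 1)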

Now for our model.

Purely by way of example, we add the final bias ourselves, rather than letting the linear layer do it. This is just so we can demonstrate how to use custom parameters in models.

class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jnp.ndarray

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)
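
As a quick aside (a sketch, not part of the example proper; demo_model is just a throwaway name), we can check that this hand-written bias really is treated as a parameter by Equinox's filtering, whilst the integer hidden_size is not:

demo_model = RNN(in_size=2, out_size=1, hidden_size=16, key=jrandom.PRNGKey(0))
demo_params = eqx.filter(demo_model, eqx.is_inexact_array)  # keep only floating-point array leaves
print(demo_params.bias)         # the zero-initialised bias array -- a parameter
print(demo_params.hidden_size)  # None -- the integer is filtered out, so it is never differentiated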

And finally the training loop.

def main(
    dataset_size=10000,
    batch_size=32,
    learning_rate=3e-3,
    steps=200,
    hidden_size=16,
    depth=1,
    seed=5678,
):
    data_key, loader_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 3)
    xs, ys = get_data(dataset_size, key=data_key)
    iter_data = dataloader((xs, ys), batch_size, key=loader_key)

    model = RNN(in_size=2, out_size=1, hidden_size=hidden_size, key=model_key)

    @eqx.filter_value_and_grad
    def compute_loss(model, x, y):
        pred_y = jax.vmap(model)(x)
        # Trains with respect to binary cross-entropy
        return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

    # Important for efficiency whenever you use JAX: wrap everything into a single JIT
    # region.
    @eqx.filter_jit
    def make_step(model, x, y, opt_state):
        loss, grads = compute_loss(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    optim = optax.adam(learning_rate)
    opt_state = optim.init(model)
    for step, (x, y) in zip(range(steps), iter_data):
        loss, model, opt_state = make_step(model, x, y, opt_state)
        loss = loss.item()
        print(f"step={step}, loss={loss}")

    pred_ys = jax.vmap(model)(xs)
    num_correct = jnp.sum((pred_ys > 0.5) == ys)
    final_accuracy = (num_correct / dataset_size).item()
    print(f"final_accuracy={final_accuracy}")

eqx.filter_value_and_grad will calculate the gradient with respect to the first argument (model). By default it will calculate gradients for all floating-point JAX arrays and ignore everything else. For example, the model parameters will be differentiated, whilst model.hidden_size is an integer and will be left alone. If you need finer control then these defaults can be adjusted; see equinox.filter_grad and equinox.filter_value_and_grad.
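
For instance, one way to get that finer control is to build a boolean filter_spec and use eqx.partition/eqx.combine. The following is only a sketch (not part of the example proper); it assumes a model and a single data batch (x, y) are in scope, and freezes the GRU cell so that only the linear layer's weight and our custom bias receive gradients:

# Mark every leaf as "do not differentiate", then switch on just the leaves we want.
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(
    lambda m: (m.linear.weight, m.bias), filter_spec, replace=(True, True)
)

@eqx.filter_value_and_grad
def loss_wrt_subset(diff_model, static_model, x, y):
    combined = eqx.combine(diff_model, static_model)
    pred_y = jax.vmap(combined)(x)
    return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

diff_model, static_model = eqx.partition(model, filter_spec)
loss, grads = loss_wrt_subset(diff_model, static_model, x, y)
# grads now has gradients only for linear.weight and bias; every other leaf is None.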

Likewise, by default, eqx.filter_jit will look at all the arguments passed to make_step, treat every array as a traced argument, and treat everything else as static. For example, the model parameters and the data x and y will be traced, whilst model.hidden_size is an integer and so will be treated as static. Once again, if you need finer control then these defaults can be adjusted; see equinox.filter_jit.
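
As a minimal sketch of the same filtering applied elsewhere (again not part of the example proper; it assumes model and xs are in scope), the same decorator can wrap an inference function: the array leaves of model and of xs become traced inputs, whilst model.hidden_size and the other non-array leaves are baked in as static.

@eqx.filter_jit
def evaluate(model, x):
    # Arrays inside `model` and `x` are traced; everything else is static.
    return jax.vmap(model)(x)

preds = evaluate(model, xs)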

main()  # All right, let's run the code.
step=0, loss=0.6968629360198975
step=1, loss=0.6928662657737732
step=2, loss=0.6924071311950684
step=3, loss=0.6936538219451904
step=4, loss=0.6929900050163269
step=5, loss=0.6925624012947083
step=6, loss=0.6939840912818909
step=7, loss=0.6934851408004761
step=8, loss=0.6915387511253357
step=9, loss=0.6939077973365784
step=10, loss=0.691526472568512
step=11, loss=0.6946976184844971
step=12, loss=0.6928755640983582
step=13, loss=0.6925874948501587
step=14, loss=0.6944625377655029
step=15, loss=0.6954206228256226
step=16, loss=0.6910819411277771
step=17, loss=0.6946226358413696
step=18, loss=0.6907551288604736
step=19, loss=0.6943591833114624
step=20, loss=0.6935365200042725
step=21, loss=0.6931085586547852
step=22, loss=0.6927573680877686
step=23, loss=0.6918144226074219
step=24, loss=0.693355917930603
step=25, loss=0.6934046745300293
step=26, loss=0.6928523778915405
step=27, loss=0.6927888989448547
step=28, loss=0.6911575794219971
step=29, loss=0.6926875114440918
step=30, loss=0.6926926374435425
step=31, loss=0.6927932500839233
step=32, loss=0.6920660734176636
step=33, loss=0.6931584477424622
step=34, loss=0.6925565004348755
step=35, loss=0.6932635307312012
step=36, loss=0.6928280591964722
step=37, loss=0.6931933164596558
step=38, loss=0.6919360160827637
step=39, loss=0.6913033723831177
step=40, loss=0.6925539970397949
step=41, loss=0.6936467885971069
step=42, loss=0.6933906078338623
step=43, loss=0.6905258893966675
step=44, loss=0.693335235118866
step=45, loss=0.6936687231063843
step=46, loss=0.6922796964645386
step=47, loss=0.6942081451416016
step=48, loss=0.6924135684967041
step=49, loss=0.6919693946838379
step=50, loss=0.6911320686340332
step=51, loss=0.6917257308959961
step=52, loss=0.6902439594268799
step=53, loss=0.6989374756813049
step=54, loss=0.6880578994750977
step=55, loss=0.6932367086410522
step=56, loss=0.6903895139694214
step=57, loss=0.6951816082000732
step=58, loss=0.6881908774375916
step=59, loss=0.6912969350814819
step=60, loss=0.692997395992279
step=61, loss=0.6937721967697144
step=62, loss=0.6948648691177368
step=63, loss=0.6913964748382568
step=64, loss=0.6929829120635986
step=65, loss=0.6920725703239441
step=66, loss=0.6956514120101929
step=67, loss=0.6899707317352295
step=68, loss=0.6912500858306885
step=69, loss=0.690929114818573
step=70, loss=0.690754771232605
step=71, loss=0.6868723630905151
step=72, loss=0.6915323138237
step=73, loss=0.6870177984237671
step=74, loss=0.6885477304458618
step=75, loss=0.6867898106575012
step=76, loss=0.68553626537323
step=77, loss=0.6821398138999939
step=78, loss=0.6817562580108643
step=79, loss=0.6883955597877502
step=80, loss=0.6734539270401001
step=81, loss=0.673201322555542
step=82, loss=0.6518104076385498
step=83, loss=0.6409814357757568
step=84, loss=0.6470382213592529
step=85, loss=0.6248712539672852
step=86, loss=0.5959113240242004
step=87, loss=0.5411312580108643
step=88, loss=0.5192182064056396
step=89, loss=0.4482913017272949
step=90, loss=0.42512020468711853
step=91, loss=0.39694473147392273
step=92, loss=0.31162846088409424
step=93, loss=0.3591296076774597
step=94, loss=0.3391856551170349
step=95, loss=0.22171661257743835
step=96, loss=0.13305065035820007
step=97, loss=0.1562710404396057
step=98, loss=0.12676352262496948
step=99, loss=0.13090220093727112
step=100, loss=0.08012909442186356
step=101, loss=0.07420005649328232
step=102, loss=0.06451655924320221
step=103, loss=0.0709628015756607
step=104, loss=0.04913963004946709
step=105, loss=0.03378813713788986
step=106, loss=0.03226414695382118
step=107, loss=0.024383708834648132
step=108, loss=0.02412772923707962
step=109, loss=0.015441099181771278
step=110, loss=0.014490555040538311
step=111, loss=0.012725356966257095
step=112, loss=0.011480225250124931
step=113, loss=0.01063367910683155
step=114, loss=0.009327923879027367
step=115, loss=0.008870435878634453
step=116, loss=0.008231064304709435
step=117, loss=0.007535358890891075
step=118, loss=0.007148533593863249
step=119, loss=0.0070632947608828545
step=120, loss=0.00653859693557024
step=121, loss=0.0059097642078995705
step=122, loss=0.005809098947793245
step=123, loss=0.005686894059181213
step=124, loss=0.005436614155769348
step=125, loss=0.005461064167320728
step=126, loss=0.005191616714000702
step=127, loss=0.005134109407663345
step=128, loss=0.004667493049055338
step=129, loss=0.004276127088814974
step=130, loss=0.004375998862087727
step=131, loss=0.0046746088191866875
step=132, loss=0.0045977989211678505
step=133, loss=0.004235218744724989
step=134, loss=0.004297919105738401
step=135, loss=0.004085398279130459
step=136, loss=0.003860234282910824
step=137, loss=0.0037155605386942625
step=138, loss=0.003936016000807285
step=139, loss=0.0038189340848475695
step=140, loss=0.0035462528467178345
step=141, loss=0.003776055295020342
step=142, loss=0.0036001354455947876
step=143, loss=0.003455652855336666
step=144, loss=0.003320368705317378
step=145, loss=0.0033631217665970325
step=146, loss=0.003434734884649515
step=147, loss=0.0035251914523541927
step=148, loss=0.0030395472422242165
step=149, loss=0.003123756032437086
step=150, loss=0.0030222469940781593
step=151, loss=0.0030912940856069326
step=152, loss=0.0030278991907835007
step=153, loss=0.0028344655875116587
step=154, loss=0.0028626525308936834
step=155, loss=0.0028719506226480007
step=156, loss=0.002740566385909915
step=157, loss=0.002740482334047556
step=158, loss=0.0027753664180636406
step=159, loss=0.002754327142611146
step=160, loss=0.0026870560832321644
step=161, loss=0.0028148717246949673
step=162, loss=0.0027450912166386843
step=163, loss=0.0025830306112766266
step=164, loss=0.0025843908078968525
step=165, loss=0.002580533269792795
step=166, loss=0.0026080277748405933
step=167, loss=0.0025642546825110912
step=168, loss=0.0023213259410113096
step=169, loss=0.00243693171069026
step=170, loss=0.0025095257442444563
step=171, loss=0.00239811884239316
step=172, loss=0.002309786155819893
step=173, loss=0.0023783245123922825
step=174, loss=0.002440847223624587
step=175, loss=0.0023895162157714367
step=176, loss=0.002205097349360585
step=177, loss=0.002265633549541235
step=178, loss=0.0022901236079633236
step=179, loss=0.0022574211470782757
step=180, loss=0.0021944043692201376
step=181, loss=0.0021767830476164818
step=182, loss=0.002231168793514371
step=183, loss=0.0021353098563849926
step=184, loss=0.0021301782689988613
step=185, loss=0.002177216112613678
step=186, loss=0.0021116407588124275
step=187, loss=0.0021116193383932114
step=188, loss=0.002098802709951997
step=189, loss=0.002040596678853035
step=190, loss=0.0019703381694853306
step=191, loss=0.002052887110039592
step=192, loss=0.002063254825770855
step=193, loss=0.0020025877747684717
step=194, loss=0.0019500794587656856
step=195, loss=0.002010410651564598
step=196, loss=0.00195988523773849
step=197, loss=0.0019289047922939062
step=198, loss=0.0019055064767599106
step=199, loss=0.0019205857533961535
final_accuracy=1.0