We begin by importing the usual libraries, setting up a very simple dataloader, and generating a toy dataset of spirals.

```python
import math

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax  # https://github.com/deepmind/optax
```
```python
def dataloader(arrays, batch_size):
    # An infinite generator of shuffled minibatches over the given arrays.
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = np.arange(dataset_size)
    while True:
        perm = np.random.permutation(indices)
        start = 0
        end = batch_size
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


def get_data(dataset_size, *, key):
    t = jnp.linspace(0, 2 * math.pi, 16)
    offset = jrandom.uniform(key, (dataset_size, 1), minval=0, maxval=2 * math.pi)
    x1 = jnp.sin(t + offset) / (1 + t)
    x2 = jnp.cos(t + offset) / (1 + t)
    y = jnp.ones((dataset_size, 1))
    # The first half of the dataset is reflected so that it spirals the other way,
    # and labelled 0 instead of 1.
    half_dataset_size = dataset_size // 2
    x1 = x1.at[:half_dataset_size].multiply(-1)
    y = y.at[:half_dataset_size].set(0)
    x = jnp.stack([x1, x2], axis=-1)
    return x, y
```
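As a quick sanity check (not part of the original example; the shapes simply follow from the definitions above), each spiral is sampled at 16 time points with 2 features, and the dataloader yields an endless stream of shuffled minibatches:

```python
# Hypothetical sanity check, assuming the imports and definitions above.
check_key = jrandom.PRNGKey(0)
check_x, check_y = get_data(256, key=check_key)
print(check_x.shape, check_y.shape)  # (256, 16, 2) (256, 1)

# The dataloader is an infinite generator, so we just take one batch from it.
batch_x, batch_y = next(dataloader((check_x, check_y), batch_size=32))
print(batch_x.shape, batch_y.shape)  # (32, 16, 2) (32, 1)
```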
Now for our model.

Purely by way of example, we handle adding on the final bias ourselves, rather than letting the linear layer do it. This is just so we can demonstrate how to use custom parameters in models.
```python
class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        # Run the GRU cell over the sequence, keeping only the final hidden state.
        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)
```
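The module is written to operate on a single, unbatched sequence of shape `(sequence_length, in_size)`; batching is handled with `jax.vmap` in the training loop below. A minimal, hypothetical usage sketch (assuming the definitions above):

```python
# Hypothetical usage sketch: the model maps one sequence of shape (16, 2) to a
# single probability of shape (1,); jax.vmap adds the batch dimension.
demo_model = RNN(in_size=2, out_size=1, hidden_size=16, key=jrandom.PRNGKey(0))
single_sequence = jnp.zeros((16, 2))
print(demo_model(single_sequence).shape)  # (1,)

batch_of_sequences = jnp.zeros((8, 16, 2))
print(jax.vmap(demo_model)(batch_of_sequences).shape)  # (8, 1)
```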
And finally the training loop.
```python
def main(
    dataset_size=10000,
    batch_size=32,
    learning_rate=3e-3,
    steps=200,
    hidden_size=16,
    depth=1,
    seed=5678,
):
    data_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 2)
    xs, ys = get_data(dataset_size, key=data_key)
    iter_data = dataloader((xs, ys), batch_size)

    model = RNN(in_size=2, out_size=1, hidden_size=hidden_size, key=model_key)

    @eqx.filter_value_and_grad
    def compute_loss(model, x, y):
        pred_y = jax.vmap(model)(x)
        # Trains with respect to binary cross-entropy
        return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

    # Important for efficiency whenever you use JAX: wrap everything into a single JIT
    # region.
    @eqx.filter_jit
    def make_step(model, x, y, opt_state):
        loss, grads = compute_loss(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    optim = optax.adam(learning_rate)
    # Initialise the optimiser on the floating-point arrays only, so that its state
    # matches the gradient pytree produced by compute_loss (which has None for
    # everything else).
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    for step, (x, y) in zip(range(steps), iter_data):
        loss, model, opt_state = make_step(model, x, y, opt_state)
        loss = loss.item()
        print(f"step={step}, loss={loss}")

    pred_ys = jax.vmap(model)(xs)
    num_correct = jnp.sum((pred_ys > 0.5) == ys)
    final_accuracy = (num_correct / dataset_size).item()
    print(f"final_accuracy={final_accuracy}")
```
`eqx.filter_value_and_grad` will calculate the gradient with respect to all floating-point arrays in the first argument (`model`). In this case the `model` parameters will be differentiated, whilst `model.hidden_size` is an integer and will get `None` as its gradient.
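For example (a hypothetical illustration, not part of the original example), the gradient pytree returned by the filtered transformation has the same structure as the model, with `None` in the non-differentiated positions:

```python
# Hypothetical illustration, assuming the definitions above: the gradients are
# a pytree shaped like the model, with None where no gradient was computed.
illus_model = RNN(in_size=2, out_size=1, hidden_size=16, key=jrandom.PRNGKey(0))

@eqx.filter_value_and_grad
def illus_loss(model, x):
    return jnp.sum(jax.vmap(model)(x))

_, illus_grads = illus_loss(illus_model, jnp.zeros((4, 16, 2)))
print(illus_grads.hidden_size)  # None -- integers aren't differentiated
print(illus_grads.bias.shape)   # (1,) -- a genuine gradient array
```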
Likewise, `eqx.filter_jit` will look at all the arguments passed to `make_step`, and automatically JIT-trace every array and JIT-static everything else. In this case the `model` parameters and the data `x` and `y` will be traced, whilst `model.hidden_size` is an integer and will be static'd instead.
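Conceptually, this array/non-array split is the same idea as `eqx.partition` and `eqx.combine` (again a hypothetical illustration, assuming the definitions above, rather than a description of `filter_jit`'s internals):

```python
# Hypothetical illustration: eqx.partition splits a model into its array leaves
# and everything else; eqx.combine puts the two halves back together.
split_model = RNN(in_size=2, out_size=1, hidden_size=16, key=jrandom.PRNGKey(0))
params, static = eqx.partition(split_model, eqx.is_array)
print(params.hidden_size)  # None -- not an array, so it lives on the static side
print(static.hidden_size)  # 16
reassembled = eqx.combine(params, static)
```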
```python
main()  # All right, let's run the code.
```