Convolutional Neural Network on MNIST

This is an introductory example, intended for those who are new to both JAX and Equinox. This example builds a CNN to classify MNIST, and demonstrates:

  • How to create a custom neural network using Equinox;
  • When and why to use the eqx.filter_{...} functions;
  • What your neural network looks like "under the hood" (as a PyTree).

The JAX ecosystem is built around a number of libraries, each of which does a single thing. So in addition to Equinox (for model building), this example also uses Optax to train the network, and jaxtyping to provide type annotations.

This example is available as a Jupyter notebook here.

What's the difference between JAX and Equinox?

JAX is the underlying library for numerical routines: it provides JIT compilation, autodifferentiation, and operations such as matrix multiplication. However, it deliberately provides nothing specific to any particular use case, such as neural networks -- these are delegated to downstream libraries.

Equinox is one such library. It provides neural network operations, plus many more advanced features. Go back and take a look at the All of Equinox page once you've finished this example!

import equinox as eqx
import jax
import jax.numpy as jnp
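import optax
import torch
import torchvision
from jaxtyping import Array, Float, Int, PyTree

# Hyperparameters

BATCH_SIZE = 64  # matches the batch shape printed below
LEARNING_RATE = 3e-4  # assumed value; tune as needed
STEPS = 300
PRINT_EVERY = 30  # matches the spacing of the training log
SEED = 5678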

key = jax.random.PRNGKey(SEED)

The dataset

We load the MNIST dataset using PyTorch.

JAX deliberately does not provide any built-in datasets or dataloaders! This is because there are already some well-curated datasets and dataloaders available elsewhere -- so it is common to use JAX alongside another library to provide these.

  • If you like PyTorch, then see here for a guide to its Dataset and DataLoader classes.
  • If you like TensorFlow, then see here for a guide to its pipeline.
  • If you like NumPy -- which is a good choice for small in-memory datasets -- then see here for an example.
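For the NumPy option, a dataloader can be as small as a generator over shuffled index batches. The following is only a sketch under assumptions: the helper name `numpy_dataloader` is hypothetical, and the arrays are zero-filled stand-ins with MNIST-like shapes rather than real data.

```python
import numpy as np


def numpy_dataloader(arrays, batch_size, *, seed=0):
    """Yield shuffled minibatches forever from same-length NumPy arrays."""
    dataset_size = arrays[0].shape[0]
    assert all(a.shape[0] == dataset_size for a in arrays)
    rng = np.random.default_rng(seed)
    while True:
        perm = rng.permutation(dataset_size)
        # Drop the final partial batch so that every batch has the same shape.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            idx = perm[start : start + batch_size]
            yield tuple(a[idx] for a in arrays)


# Stand-in data with MNIST-like shapes (not real images).
images = np.zeros((256, 1, 28, 28), dtype=np.float32)
labels = np.arange(256) % 10
batch_x, batch_y = next(numpy_dataloader((images, labels), batch_size=64))
```

Keeping every batch the same shape is a deliberate choice here: a JIT-compiled function is recompiled whenever its input shapes change.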
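normalise_data = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5,), (0.5,))]
)
# The root directory "MNIST" and download=True are assumptions; adjust as needed.
train_dataset = torchvision.datasets.MNIST(
    "MNIST", train=True, download=True, transform=normalise_data
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST", train=False, download=True, transform=normalise_data
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)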
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)
(64, 1, 28, 28)
(64,)
[0 7 1 2 2 2 7 5 5 6 3 7 4 6 5 8 3 3 2 4 0 9 2 9 1 1 0 7 9 9 7 0 8 1 8 1 4
 4 4 5 1 3 8 3 3 0 0 8 3 6 1 0 0 9 2 4 6 6 0 7 7 1 8 7]

We can see that our input has the shape (64, 1, 28, 28). 64 is the batch size, 1 is the number of input channels (MNIST is greyscale) and 28x28 are the height and width of the image in pixels. The label is of shape (64,), and each value is a number from 0 to 9.

The model

Our convolutional neural network (CNN) will store a list of all its operations. There is no explicit requirement to do it that way; it's simply convenient for this example.

These operations can be any JAX operation. Some of these will be Equinox's built-in layers (e.g. convolutions), and some of them will be functions from JAX itself (e.g. jax.nn.relu as an activation function).

class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
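        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]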

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x

key, subkey = jax.random.split(key, 2)
model = CNN(subkey)

As with everything in Equinox, our model is a PyTree. That is to say, just a nested collection of objects. Some of these objects are JAX arrays; for example model.layers[0].weight is the kernel of our convolution. And some of these objects are essentially arbitrary Python objects; for example model.layers[-1] is jax.nn.log_softmax, which is just a Python function like any other.

Equinox provides a nice __repr__ for its modules, so we can just print out what our PyTree looks like:

print(model)
CNN(
  layers=[
    Conv2d(
      ...
      kernel_size=(4, 4),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      ...
    ),
    MaxPool2d(
      ...
      operation=<function max>,
      ...
      kernel_size=(2, 2),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      ...
    ),
    <wrapped function relu>,
    <wrapped function ravel>,
    Linear(...),
    <wrapped function sigmoid>,
    Linear(...),
    <wrapped function relu>,
    Linear(...),
    <wrapped function log_softmax>
  ]
)

Given some data, we can perform inference on our model.

(Note that here we are using JAX operations outside of a JIT'd region. This is very slow! You shouldn't write it like this except when exploring things in a notebook.)

def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
    # a single input image of shape (1, 28, 28).
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)

def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)

# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions
()
(64, 10)
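The indexing inside cross_entropy can be hard to read at first. As a small worked sketch with made-up numbers (two samples, three classes), `jnp.take_along_axis` picks out, for each row, the log-probability assigned to the true class, and the loss is the negative mean of those picked values.

```python
import jax
import jax.numpy as jnp

# Two samples, three classes. Each row becomes a vector of log-probabilities.
logits = jnp.array([[2.0, 0.5, 0.1], [0.2, 0.3, 3.0]])
log_probs = jax.nn.log_softmax(logits, axis=1)
targets = jnp.array([0, 2])  # true class of each sample

# Pick out, for each row, the log-probability of the true class.
picked = jnp.take_along_axis(log_probs, jnp.expand_dims(targets, 1), axis=1)
loss_value = -jnp.mean(picked)

# The same quantity written with explicit indexing, as a cross-check.
manual = -(log_probs[0, 0] + log_probs[1, 2]) / 2
```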


Filtering

In the next cells we can see an example of when we should use the filter methods provided by Equinox. For instance, the following code generates an error:

# This is an error!
jax.value_and_grad(loss)(model, dummy_x, dummy_y)
TypeError                                 Traceback (most recent call last)
Cell In[8], line 2
      1 # This is an error!
----> 2 jax.value_and_grad(loss)(model, dummy_x, dummy_y)

    [... skipping hidden 3 frame]

File ~/miniconda3/envs/jax38/lib/python3.8/site-packages/jax/_src/, in check_arg(arg)
    641 def check_arg(arg):
    642   if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
--> 643     raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
    644                     "JAX type.")

TypeError: Argument '<function max at 0x7dcabbebee50>' of type <class 'function'> is not a valid JAX type.

When we write jax.value_and_grad(loss)(model, ...), we are asking JAX to differentiate the function loss with respect to its first argument model. (To compute the gradients on its parameters.)

However, model includes several things that aren't parameters! Look back up at the PyTree print-out from earlier, and we see lines like <wrapped function relu> -- this isn't a parameter and isn't even an array.

We need to split our model into the bit we want to differentiate (its parameters), and the bit we don't (everything else). If we want to, then we can do this manually:

# This will work!
params, static = eqx.partition(model, eqx.is_array)

def loss2(params, static, x, y):
    model = eqx.combine(params, static)
    return loss(model, x, y)

loss_value, grads = jax.value_and_grad(loss2)(params, static, dummy_x, dummy_y)

It's quite common that all arrays represent parameters, so that "the bit we want to differentiate" really just means "all arrays". As such, Equinox provides a convenient wrapper eqx.filter_value_and_grad, which does the above partitioning-and-combining for us: it automatically splits things into arrays and non-arrays, and then differentiates with respect to all arrays in the first argument:

# This will work too!
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)

The Equinox eqx.filter_{...} functions behave essentially like the corresponding JAX functions, except that they are smart enough to handle non-arrays without raising an error. So if you're unsure, you can simply always use the Equinox filter functions.


Evaluation

As with most machine learning tasks, we need some way to evaluate our model on test data. For this we create the following functions.

Notice that we use eqx.filter_jit instead of jax.jit. As usual, our model contains non-arrays (e.g. those relu activation functions), which jax.jit cannot accept as traced inputs; eqx.filter_jit traces the arrays and treats everything else as static.

loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!

@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)

def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

evaluate(model, testloader)
(Array(2.3077886, dtype=float32), Array(0.10111465, dtype=float32))


Training

Now it's time to train our model using Optax!

optim = optax.adamw(LEARNING_RATE)

def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Pass only the array leaves as parameters, mirroring `optim.init` above.
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model

model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)
step=0, train_loss=2.32609486579895, test_loss=2.299288749694824, test_accuracy=0.10161226242780685
step=30, train_loss=2.2150564193725586, test_loss=2.1842434406280518, test_accuracy=0.4199840724468231
step=60, train_loss=1.9649711847305298, test_loss=1.9000366926193237, test_accuracy=0.618829607963562
step=90, train_loss=1.4872171878814697, test_loss=1.4725608825683594, test_accuracy=0.6449044346809387
step=120, train_loss=1.050407886505127, test_loss=1.0521366596221924, test_accuracy=0.7916003465652466
step=150, train_loss=0.8088936805725098, test_loss=0.7538199424743652, test_accuracy=0.8578821420669556
step=180, train_loss=0.6006966829299927, test_loss=0.574236273765564, test_accuracy=0.865545392036438
step=210, train_loss=0.33910322189331055, test_loss=0.4889797866344452, test_accuracy=0.8819665312767029
step=240, train_loss=0.33334940671920776, test_loss=0.44309598207473755, test_accuracy=0.8862460255622864
step=270, train_loss=0.3595482110977173, test_loss=0.3812088072299957, test_accuracy=0.897292971611023
step=299, train_loss=0.35001736879348755, test_loss=0.3582405149936676, test_accuracy=0.9039610028266907

This is actually a pretty bad final accuracy, as MNIST is so easy. Try tweaking this example to make it better!

Next steps

Hopefully this example has given you a taste of how models are built using JAX and Equinox. For next steps, take a look at the JAX documentation for more information on JAX, the All of Equinox page for a summary of everything Equinox can do, or training an RNN for another example.