Vision Transformer (ViT)

This example builds a Vision Transformer (ViT) model using Equinox. The implementation is based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

In addition to this tutorial example, you may also like the ViT implementation available in Eqxvision.


This example will take a short while to run on a GPU.


import functools

import einops  #
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax  #

# We'll use PyTorch to load the dataset.
import torch
import torchvision
import torchvision.transforms as transforms
from jaxtyping import Array, Float, PRNGKeyArray
# Hyperparameters

# Optimiser and training-loop settings.
lr = 0.0001
beta1 = 0.9
beta2 = 0.999
batch_size = 64
num_steps = 100000

# Input data description: (height, width, channels) and number of labels.
image_size = (32, 32, 3)
height, width, channels = image_size
num_classes = 10

# Model architecture.
patch_size = 4
# One token per non-overlapping patch_size x patch_size tile. Derived rather
# than hard-coded so it cannot drift out of sync with image_size/patch_size
# (evaluates to 64 for the values above).
num_patches = (height // patch_size) * (width // patch_size)
embedding_dim = 512
hidden_dim = 256
num_heads = 8
num_layers = 6
dropout_rate = 0.1

Let's first load the CIFAR10 dataset using torchvision:

# Training-time preprocessing: random-crop augmentation, resize to the model's
# input resolution, then tensor conversion and per-channel normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.Resize((height, width)),
        # Normalize operates on tensors, so the PIL image is converted first.
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Evaluation-time preprocessing: same normalisation, no augmentation.
transform_test = transforms.Compose(
    [
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# NOTE(review): the dataset root directory ("CIFAR") and download=True are
# reconstructed defaults -- adjust to wherever the data should be stored.
train_dataset = torchvision.datasets.CIFAR10(
    "CIFAR",
    train=True,
    download=True,
    transform=transform_train,
)

test_dataset = torchvision.datasets.CIFAR10(
    "CIFAR",
    train=False,
    download=True,
    transform=transform_test,
)

# drop_last=True keeps every batch at exactly `batch_size` examples, so the
# jitted training step always sees the same input shape.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
Now let's start by making the patch embedding layer, which turns an image into a sequence of embedded patches that are then processed by the attention layers.

class PatchEmbedding(eqx.Module):
    """Split an image into square patches and linearly embed each patch."""

    # Shared projection applied to every flattened patch. (The field holds an
    # eqx.nn.Linear, so it is annotated as such rather than as an Embedding.)
    linear: eqx.nn.Linear
    # Side length, in pixels, of each square patch.
    patch_size: int

    def __init__(
        self,
        input_channels: int,
        output_shape: int,
        patch_size: int,
        key: PRNGKeyArray,
    ):
        """Build the patch projection.

        Args:
            input_channels: number of channels in the input image.
            output_shape: size of the embedding produced for each patch.
            patch_size: side length of each square patch.
            key: PRNG key used to initialise the linear projection.
        """
        self.patch_size = patch_size

        # Each flattened patch has patch_size * patch_size * channels values.
        self.linear = eqx.nn.Linear(
            self.patch_size**2 * input_channels,
            output_shape,
            key=key,
        )

    def __call__(
        self, x: Float[Array, "channels height width"]
    ) -> Float[Array, "num_patches embedding_dim"]:
        """Return one embedding per patch of the (unbatched) image `x`.

        The image height and width must be divisible by `patch_size`.
        """
        # Cut the image into an (h, w) grid of (ph, pw) tiles and flatten each
        # tile, together with its channels, into a single vector.
        x = einops.rearrange(
            x,
            "c (h ph) (w pw) -> (h w) (c ph pw)",
            ph=self.patch_size,
            pw=self.patch_size,
        )
        # eqx.nn.Linear acts on a single vector, so vmap over the patch axis.
        x = jax.vmap(self.linear)(x)

        return x

After that, we implement the attention block which is the core of the transformer architecture.

class AttentionBlock(eqx.Module):
    """Pre-norm transformer block: self-attention followed by a two-layer MLP.

    Both sub-layers are wrapped in residual connections.
    """

    layer_norm1: eqx.nn.LayerNorm
    layer_norm2: eqx.nn.LayerNorm
    attention: eqx.nn.MultiheadAttention
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    dropout1: eqx.nn.Dropout
    dropout2: eqx.nn.Dropout

    def __init__(
        self,
        input_shape: int,
        hidden_dim: int,
        num_heads: int,
        dropout_rate: float,
        key: PRNGKeyArray,
    ):
        """Build the block.

        Args:
            input_shape: embedding size of each token.
            hidden_dim: width of the hidden layer of the MLP.
            num_heads: number of attention heads.
            dropout_rate: dropout probability used inside the MLP.
            key: PRNG key, split between the attention and linear layers.
        """
        key1, key2, key3 = jr.split(key, 3)

        self.layer_norm1 = eqx.nn.LayerNorm(input_shape)
        self.layer_norm2 = eqx.nn.LayerNorm(input_shape)
        self.attention = eqx.nn.MultiheadAttention(num_heads, input_shape, key=key1)

        self.linear1 = eqx.nn.Linear(input_shape, hidden_dim, key=key2)
        self.linear2 = eqx.nn.Linear(hidden_dim, input_shape, key=key3)
        self.dropout1 = eqx.nn.Dropout(dropout_rate)
        self.dropout2 = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        x: Float[Array, "num_patches embedding_dim"],
        enable_dropout: bool,
        key: PRNGKeyArray,
    ) -> Float[Array, "num_patches embedding_dim"]:
        """Apply the block to a sequence of token embeddings.

        Args:
            x: token embeddings for a single (unbatched) example.
            enable_dropout: whether the MLP dropout layers are active.
            key: PRNG key for the two dropout layers.

        Returns:
            The transformed token embeddings, same shape as `x`.
        """
        # Self-attention sub-layer: query, key and value are all the
        # normalised input. LayerNorm acts per token, hence the vmap.
        input_x = jax.vmap(self.layer_norm1)(x)
        x = x + self.attention(input_x, input_x, input_x)

        # MLP sub-layer, applied independently to every token.
        input_x = jax.vmap(self.layer_norm2)(x)
        input_x = jax.vmap(self.linear1)(input_x)
        input_x = jax.nn.gelu(input_x)

        # Independent randomness for each dropout layer.
        key1, key2 = jr.split(key, num=2)

        input_x = self.dropout1(input_x, inference=not enable_dropout, key=key1)
        input_x = jax.vmap(self.linear2)(input_x)
        input_x = self.dropout2(input_x, inference=not enable_dropout, key=key2)

        x = x + input_x

        return x

Lastly, we build the full Vision Transformer model, which is composed of embeddings layers, a series of transformer blocks, and a classification head.

class VisionTransformer(eqx.Module):
    """Vision Transformer classifier for a single (unbatched) image.

    Pipeline: patch embedding -> prepend CLS token -> add positional
    embedding -> dropout -> stack of attention blocks -> classification head
    applied to the CLS token.
    """

    patch_embedding: PatchEmbedding
    positional_embedding: jnp.ndarray
    cls_token: jnp.ndarray
    attention_blocks: list[AttentionBlock]
    dropout: eqx.nn.Dropout
    mlp: eqx.nn.Sequential
    num_layers: int

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        dropout_rate: float,
        patch_size: int,
        num_patches: int,
        num_classes: int,
        key: PRNGKeyArray,
    ):
        """Build the model.

        Args:
            embedding_dim: size of each token embedding.
            hidden_dim: hidden width of the MLP inside each attention block.
            num_heads: number of attention heads per block.
            num_layers: number of attention blocks.
            dropout_rate: dropout probability used throughout the model.
            patch_size: side length of each square image patch.
            num_patches: maximum number of patches per image; sizes the
                positional embedding table.
            num_classes: number of output logits.
            key: PRNG key used to initialise all parameters.
        """
        key1, key2, key3, key4, key5 = jr.split(key, 5)

        # NOTE: `channels` is read from the module-level hyperparameters.
        self.patch_embedding = PatchEmbedding(channels, embedding_dim, patch_size, key1)

        # One extra position for the CLS token.
        self.positional_embedding = jr.normal(key2, (num_patches + 1, embedding_dim))

        self.cls_token = jr.normal(key3, (1, embedding_dim))

        self.num_layers = num_layers

        # Give every block its own key. Reusing a single key for all blocks
        # would initialise every layer with identical weights.
        block_keys = jr.split(key4, num_layers)
        self.attention_blocks = [
            AttentionBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, block_key)
            for block_key in block_keys
        ]

        self.dropout = eqx.nn.Dropout(dropout_rate)

        # Classification head: normalise the CLS embedding, then project it
        # to class logits.
        self.mlp = eqx.nn.Sequential(
            [
                eqx.nn.LayerNorm(embedding_dim),
                eqx.nn.Linear(embedding_dim, num_classes, key=key5),
            ]
        )

    def __call__(
        self,
        x: Float[Array, "channels height width"],
        enable_dropout: bool,
        key: PRNGKeyArray,
    ) -> Float[Array, "num_classes"]:
        """Compute class logits for a single image.

        Args:
            x: one image, channels first.
            enable_dropout: True during training, False for inference.
            key: PRNG key for the dropout layers.

        Returns:
            Unnormalised class logits.
        """
        x = self.patch_embedding(x)

        # Prepend the learned CLS token to the patch sequence.
        x = jnp.concatenate((self.cls_token, x), axis=0)

        x += self.positional_embedding[
            : x.shape[0]
        ]  # Slice to the same length as x, as the positional embedding may be longer.

        # One key for the embedding dropout plus one per attention block.
        dropout_key, *attention_keys = jr.split(key, num=self.num_layers + 1)

        x = self.dropout(x, inference=not enable_dropout, key=dropout_key)

        for block, attention_key in zip(self.attention_blocks, attention_keys):
            x = block(x, enable_dropout, key=attention_key)

        x = x[0]  # Select the CLS token.
        x = self.mlp(x)

        return x
# The decorator makes this return `(loss, grads)`, differentiating with respect
# to the floating-point array leaves of the first argument (`model`).
@eqx.filter_value_and_grad
def compute_grads(
    model: VisionTransformer, images: jnp.ndarray, labels: jnp.ndarray, key
):
    """Mean cross-entropy loss of `model` on one batch, with dropout enabled.

    Args:
        model: the model being trained.
        images: batch of images, batch axis first.
        labels: integer class label for each image.
        key: one PRNG key per example, batch axis first.
    """
    # Map over the batch axis of the images and keys; the `True`
    # (enable_dropout) flag is shared by every example.
    logits = jax.vmap(model, in_axes=(0, None, 0))(images, True, key)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)

    return jnp.mean(loss)

@eqx.filter_jit
def step_model(
    model: VisionTransformer,
    optimizer: optax.GradientTransformation,
    state: optax.OptState,
    images: jnp.ndarray,
    labels: jnp.ndarray,
    key,
):
    """Run one optimisation step on a single batch.

    Args:
        model: current model.
        optimizer: optax optimiser.
        state: current optimiser state.
        images: batch of images.
        labels: integer class label for each image.
        key: one PRNG key per example, forwarded to `compute_grads`.

    Returns:
        `(model, new_state, loss)`: the updated model, the updated optimiser
        state and the batch loss.
    """
    loss, grads = compute_grads(model, images, labels, key)
    # Pass the same filtered view of the model that the optimiser state was
    # initialised with, so the params tree matches the gradient tree (which
    # has no entries for non-floating-point leaves).
    updates, new_state = optimizer.update(
        grads, state, eqx.filter(model, eqx.is_inexact_array)
    )

    model = eqx.apply_updates(model, updates)

    return model, new_state, loss

def train(
    model: VisionTransformer,
    optimizer: optax.GradientTransformation,
    state: optax.OptState,
    data_loader,
    num_steps: int,
    print_every: int = 1000,
    key=None,
):
    """Train `model` for `num_steps` optimisation steps.

    Args:
        model: model to train.
        optimizer: optax optimiser.
        state: optimiser state matching `optimizer` and `model`.
        data_loader: iterable yielding `(images, labels)` batches of torch
            tensors; it is restarted whenever it is exhausted.
        num_steps: total number of optimisation steps.
        print_every: print the loss every this many steps.
        key: PRNG key used to derive the per-example dropout keys.

    Returns:
        `(model, state, losses)`, where `losses` holds the loss of every step.

    Raises:
        ValueError: if `key` is not provided.
    """
    if key is None:
        raise ValueError("train() requires a PRNG key (pass key=...).")

    losses = []

    def infinite_trainloader():
        # Cycle through the dataset for as many epochs as the steps require.
        while True:
            yield from data_loader

    for step, batch in zip(range(num_steps), infinite_trainloader()):
        images, labels = batch

        images = images.numpy()
        labels = labels.numpy()

        # One fresh key to carry forward plus one key per example in this
        # batch (sized from the batch itself rather than a global constant).
        keys = jr.split(key, num=images.shape[0] + 1)
        key, subkeys = keys[0], keys[1:]

        (model, state, loss) = step_model(
            model, optimizer, state, images, labels, subkeys
        )

        # Record the loss so the returned history is actually populated.
        losses.append(loss)

        if (step % print_every) == 0 or step == num_steps - 1:
            print(f"Step: {step}/{num_steps}, Loss: {loss}.")

    return model, state, losses
# Split the seed key so that model initialisation, training and the later
# evaluation each use an independent key instead of reusing one.
key, model_key, train_key = jr.split(jr.PRNGKey(2003), 3)

model = VisionTransformer(
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout_rate=dropout_rate,
    patch_size=patch_size,
    num_patches=num_patches,
    num_classes=num_classes,
    key=model_key,
)

optimizer = optax.adamw(
    learning_rate=lr,
    b1=beta1,
    b2=beta2,
)

# Only floating-point array leaves are trainable parameters.
state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

model, state, losses = train(
    model, optimizer, state, trainloader, num_steps, key=train_key
)
Step: 0/100000, Loss: 2.5608019828796387.
Step: 1000/100000, Loss: 1.711548089981079.
Step: 2000/100000, Loss: 1.4029508829116821.
Step: 3000/100000, Loss: 1.405516505241394.
Step: 4000/100000, Loss: 1.1661641597747803.
Step: 5000/100000, Loss: 1.1351711750030518.
Step: 6000/100000, Loss: 1.11599600315094.
Step: 7000/100000, Loss: 0.796968936920166.
Step: 8000/100000, Loss: 0.6870157718658447.
Step: 9000/100000, Loss: 1.0474591255187988.
Step: 10000/100000, Loss: 0.9413787722587585.
Step: 11000/100000, Loss: 0.8514565229415894.
Step: 12000/100000, Loss: 0.6746965646743774.
Step: 13000/100000, Loss: 0.7895829677581787.
Step: 14000/100000, Loss: 0.6844460964202881.
Step: 15000/100000, Loss: 0.6571178436279297.
Step: 16000/100000, Loss: 0.5611618757247925.
Step: 17000/100000, Loss: 0.610838770866394.
Step: 18000/100000, Loss: 0.7180566787719727.
Step: 19000/100000, Loss: 0.6528561115264893.
Step: 20000/100000, Loss: 0.5517654418945312.
Step: 21000/100000, Loss: 0.6301887035369873.
Step: 22000/100000, Loss: 0.5667067766189575.
Step: 23000/100000, Loss: 0.43517759442329407.
Step: 24000/100000, Loss: 0.5348870754241943.
Step: 25000/100000, Loss: 0.44732385873794556.
Step: 26000/100000, Loss: 0.49118855595588684.
Step: 27000/100000, Loss: 0.5242345929145813.
Step: 28000/100000, Loss: 0.44588926434516907.
Step: 29000/100000, Loss: 0.23619337379932404.
Step: 30000/100000, Loss: 0.4560542702674866.
Step: 31000/100000, Loss: 0.3148268163204193.
Step: 32000/100000, Loss: 0.4813237488269806.
Step: 33000/100000, Loss: 0.40532559156417847.
Step: 34000/100000, Loss: 0.2517223358154297.
Step: 35000/100000, Loss: 0.322698712348938.
Step: 36000/100000, Loss: 0.3052283525466919.
Step: 37000/100000, Loss: 0.37322986125946045.
Step: 38000/100000, Loss: 0.27499520778656006.
Step: 39000/100000, Loss: 0.2547920346260071.
Step: 40000/100000, Loss: 0.27322614192962646.
Step: 41000/100000, Loss: 0.6049947738647461.
Step: 42000/100000, Loss: 0.28800976276397705.
Step: 43000/100000, Loss: 0.2901820242404938.
Step: 44000/100000, Loss: 0.3800655007362366.
Step: 45000/100000, Loss: 0.15261484682559967.
Step: 46000/100000, Loss: 0.17970965802669525.
Step: 47000/100000, Loss: 0.23651015758514404.
Step: 48000/100000, Loss: 0.3813527822494507.
Step: 49000/100000, Loss: 0.35252541303634644.
Step: 50000/100000, Loss: 0.16249465942382812.
Step: 51000/100000, Loss: 0.10218428075313568.
Step: 52000/100000, Loss: 0.2192973792552948.
Step: 53000/100000, Loss: 0.1880446970462799.
Step: 54000/100000, Loss: 0.14270251989364624.
Step: 55000/100000, Loss: 0.1278090476989746.
Step: 56000/100000, Loss: 0.0856819674372673.
Step: 57000/100000, Loss: 0.16201086342334747.
Step: 58000/100000, Loss: 0.20575015246868134.
Step: 59000/100000, Loss: 0.20935538411140442.
Step: 60000/100000, Loss: 0.09025183320045471.
Step: 61000/100000, Loss: 0.21367806196212769.
Step: 62000/100000, Loss: 0.06895419955253601.
Step: 63000/100000, Loss: 0.14567255973815918.
Step: 64000/100000, Loss: 0.18438486754894257.
Step: 65000/100000, Loss: 0.11639232933521271.
Step: 66000/100000, Loss: 0.06631053984165192.
Step: 67000/100000, Loss: 0.11763929575681686.
Step: 68000/100000, Loss: 0.046494871377944946.
Step: 69000/100000, Loss: 0.14044761657714844.
Step: 70000/100000, Loss: 0.1277393102645874.
Step: 71000/100000, Loss: 0.154437854886055.
Step: 72000/100000, Loss: 0.15087449550628662.
Step: 73000/100000, Loss: 0.05043340474367142.
Step: 74000/100000, Loss: 0.3183276355266571.
Step: 75000/100000, Loss: 0.15685151517391205.
Step: 76000/100000, Loss: 0.13796621561050415.
Step: 77000/100000, Loss: 0.1036764532327652.
Step: 78000/100000, Loss: 0.08222786337137222.
Step: 79000/100000, Loss: 0.1525675356388092.
Step: 80000/100000, Loss: 0.06328584253787994.
Step: 81000/100000, Loss: 0.1235610619187355.
Step: 82000/100000, Loss: 0.03093503788113594.
Step: 83000/100000, Loss: 0.07480041682720184.
Step: 84000/100000, Loss: 0.016707731410861015.
Step: 85000/100000, Loss: 0.0491723008453846.
Step: 86000/100000, Loss: 0.0650872215628624.
Step: 87000/100000, Loss: 0.08738622069358826.
Step: 88000/100000, Loss: 0.10671466588973999.
Step: 89000/100000, Loss: 0.11922930181026459.
Step: 90000/100000, Loss: 0.1234014481306076.
Step: 91000/100000, Loss: 0.08588997274637222.
Step: 92000/100000, Loss: 0.036773063242435455.
Step: 93000/100000, Loss: 0.03425668179988861.
Step: 94000/100000, Loss: 0.21202465891838074.
Step: 95000/100000, Loss: 0.26020047068595886.
Step: 96000/100000, Loss: 0.154791459441185.
Step: 97000/100000, Loss: 0.1340092271566391.
Step: 98000/100000, Loss: 0.11398129910230637.
Step: 99000/100000, Loss: 0.16246598958969116.
Step: 99999/100000, Loss: 0.04668630287051201.

And now let's see how the vision transformer performs on the CIFAR10 dataset.

accuracies = []

# Iterate the loader directly. Calling next(iter(testloader)) on every
# iteration would build a fresh shuffled iterator each time and evaluate on
# randomly re-sampled batches instead of one pass over the test set.
for images, labels in testloader:
    # Dropout is disabled for evaluation; the model still takes one key per
    # example, so split one for each image in the batch.
    logits = jax.vmap(functools.partial(model, enable_dropout=False))(
        images.numpy(), key=jax.random.split(key, num=images.shape[0])
    )

    predictions = jnp.argmax(logits, axis=-1)

    accuracy = jnp.mean(predictions == labels.numpy())

    accuracies.append(accuracy)

print(f"Accuracy: {np.sum(accuracies) / len(accuracies) * 100}%")
Accuracy: 79.13661858974359%

Of course this is not the best accuracy you can get on CIFAR10, but with more training and hyperparameter tuning, you can get better results using the vision transformer.