Vision Transformer (ViT)

This example builds a vision transformer (ViT) model using Equinox, implementing the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.

In addition to this tutorial example, you may also like the ViT implementation available in Eqxvision.

Warning

This example will take a short while to run on a GPU.

Reference

arXiv link

@inproceedings{dosovitskiy2021an,
    title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author={Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn
            and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer
            and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    booktitle={International Conference on Learning Representations},
    year={2021},
}
import functools

import einops  # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax  # https://github.com/deepmind/optax

# We'll use PyTorch to load the dataset.
import torch
import torchvision
import torchvision.transforms as transforms
from jaxtyping import Array, Float, PRNGKeyArray
# Hyperparameters
lr = 0.0001
dropout_rate = 0.1
beta1 = 0.9
beta2 = 0.999
batch_size = 64
patch_size = 4
num_patches = 64
num_steps = 100000
image_size = (32, 32, 3)
embedding_dim = 512
hidden_dim = 256
num_heads = 8
num_layers = 6
height, width, channels = image_size
num_classes = 10
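
These values are related: a 32x32 image split into non-overlapping 4x4 patches gives (32 / 4) * (32 / 4) = 64 patches, which is where num_patches comes from. A quick sanity check of that arithmetic, using the names defined above:

# num_patches must match the grid of non-overlapping patch_size x patch_size patches.
assert height % patch_size == 0 and width % patch_size == 0
assert num_patches == (height // patch_size) * (width // patch_size)  # 8 * 8 = 64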

Let's first load the CIFAR10 dataset using torchvision.

transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.Resize((height, width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_dataset = torchvision.datasets.CIFAR10(
    "CIFAR",
    train=True,
    download=True,
    transform=transform_train,
)

test_dataset = torchvision.datasets.CIFAR10(
    "CIFAR",
    train=False,
    download=True,
    transform=transform_test,
)

trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
Files already downloaded and verified
Files already downloaded and verified
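
Each batch from these loaders is a pair of PyTorch tensors, which we will convert to NumPy arrays before handing them to JAX. As a minimal sketch of what a batch looks like (the peek_images/peek_labels names are just for illustration):

# Images come out as (batch, channels, height, width); labels as (batch,).
peek_images, peek_labels = next(iter(trainloader))
print(peek_images.numpy().shape)  # (64, 3, 32, 32)
print(peek_labels.numpy().shape)  # (64,)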

Now let's start by making the patch embedding layer, which turns an image into a sequence of embedded patches that the attention layers can then process.

class PatchEmbedding(eqx.Module):
    linear: eqx.nn.Linear
    patch_size: int

    def __init__(
        self,
        input_channels: int,
        output_shape: int,
        patch_size: int,
        key: PRNGKeyArray,
    ):
        self.patch_size = patch_size

        self.linear = eqx.nn.Linear(
            self.patch_size**2 * input_channels,
            output_shape,
            key=key,
        )

    def __call__(
        self, x: Float[Array, "channels height width"]
    ) -> Float[Array, "num_patches embedding_dim"]:
        x = einops.rearrange(
            x,
            "c (h ph) (w pw) -> (h w) (c ph pw)",
            ph=self.patch_size,
            pw=self.patch_size,
        )
        x = jax.vmap(self.linear)(x)

        return x
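
As a quick shape check, embedding one CIFAR10-sized image should give one vector per patch. This is only a minimal sketch; the demo_key, patch_embedder and demo_image names are introduced here for illustration and are not used later:

# A 3x32x32 image split into 4x4 patches gives 64 patches, each embedded to 512 dimensions.
demo_key = jr.PRNGKey(0)
patch_embedder = PatchEmbedding(channels, embedding_dim, patch_size, demo_key)
demo_image = jr.normal(demo_key, (channels, height, width))
print(patch_embedder(demo_image).shape)  # (64, 512)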

After that, we implement the attention block, which is the core of the transformer architecture.

class AttentionBlock(eqx.Module):
    layer_norm1: eqx.nn.LayerNorm
    layer_norm2: eqx.nn.LayerNorm
    attention: eqx.nn.MultiheadAttention
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    dropout1: eqx.nn.Dropout
    dropout2: eqx.nn.Dropout

    def __init__(
        self,
        input_shape: int,
        hidden_dim: int,
        num_heads: int,
        dropout_rate: float,
        key: PRNGKeyArray,
    ):
        key1, key2, key3 = jr.split(key, 3)

        self.layer_norm1 = eqx.nn.LayerNorm(input_shape)
        self.layer_norm2 = eqx.nn.LayerNorm(input_shape)
        self.attention = eqx.nn.MultiheadAttention(num_heads, input_shape, key=key1)

        self.linear1 = eqx.nn.Linear(input_shape, hidden_dim, key=key2)
        self.linear2 = eqx.nn.Linear(hidden_dim, input_shape, key=key3)
        self.dropout1 = eqx.nn.Dropout(dropout_rate)
        self.dropout2 = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        x: Float[Array, "num_patches embedding_dim"],
        enable_dropout: bool,
        key: PRNGKeyArray,
    ) -> Float[Array, "num_patches embedding_dim"]:
        input_x = jax.vmap(self.layer_norm1)(x)
        x = x + self.attention(input_x, input_x, input_x)

        input_x = jax.vmap(self.layer_norm2)(x)
        input_x = jax.vmap(self.linear1)(input_x)
        input_x = jax.nn.gelu(input_x)

        key1, key2 = jr.split(key, num=2)

        input_x = self.dropout1(input_x, inference=not enable_dropout, key=key1)
        input_x = jax.vmap(self.linear2)(input_x)
        input_x = self.dropout2(input_x, inference=not enable_dropout, key=key2)

        x = x + input_x

        return x
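
Note that the block maps a sequence of patch embeddings to a sequence of the same shape, which is what allows us to stack several of them. A minimal sketch of a single block applied to random embeddings (the demo names are for illustration only):

# One attention block preserves the (num_patches, embedding_dim) shape.
demo_key = jr.PRNGKey(0)
demo_block = AttentionBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, demo_key)
demo_patches = jr.normal(demo_key, (num_patches, embedding_dim))
print(demo_block(demo_patches, enable_dropout=False, key=demo_key).shape)  # (64, 512)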

Lastly, we build the full Vision Transformer model, which is composed of the patch embedding layer, a series of attention blocks, and a classification head.

class VisionTransformer(eqx.Module):
    patch_embedding: PatchEmbedding
    positional_embedding: jnp.ndarray
    cls_token: jnp.ndarray
    attention_blocks: list[AttentionBlock]
    dropout: eqx.nn.Dropout
    mlp: eqx.nn.Sequential
    num_layers: int

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        dropout_rate: float,
        patch_size: int,
        num_patches: int,
        num_classes: int,
        key: PRNGKeyArray,
    ):
        key1, key2, key3, key4, key5 = jr.split(key, 5)

        self.patch_embedding = PatchEmbedding(channels, embedding_dim, patch_size, key1)

        self.positional_embedding = jr.normal(key2, (num_patches + 1, embedding_dim))

        self.cls_token = jr.normal(key3, (1, embedding_dim))

        self.num_layers = num_layers

        self.attention_blocks = [
            AttentionBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, block_key)
            for block_key in jr.split(key4, self.num_layers)
        ]

        self.dropout = eqx.nn.Dropout(dropout_rate)

        self.mlp = eqx.nn.Sequential(
            [
                eqx.nn.LayerNorm(embedding_dim),
                eqx.nn.Linear(embedding_dim, num_classes, key=key5),
            ]
        )

    def __call__(
        self,
        x: Float[Array, "channels height width"],
        enable_dropout: bool,
        key: PRNGKeyArray,
    ) -> Float[Array, "num_classes"]:
        x = self.patch_embedding(x)

        x = jnp.concatenate((self.cls_token, x), axis=0)

        x += self.positional_embedding[
            : x.shape[0]
        ]  # Slice to the same length as x, as the positional embedding may be longer.

        dropout_key, *attention_keys = jr.split(key, num=self.num_layers + 1)

        x = self.dropout(x, inference=not enable_dropout, key=dropout_key)

        for block, attention_key in zip(self.attention_blocks, attention_keys):
            x = block(x, enable_dropout, key=attention_key)

        x = x[0]  # Select the CLS token.
        x = self.mlp(x)

        return x
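
A single forward pass takes one image and returns one logit per class; batching over images is handled later with jax.vmap. A minimal sketch of one forward pass on a random image (the demo names are for illustration only):

# One image in, num_classes logits out (taken from the CLS token after the classification head).
demo_key = jr.PRNGKey(0)
demo_vit = VisionTransformer(
    embedding_dim, hidden_dim, num_heads, num_layers,
    dropout_rate, patch_size, num_patches, num_classes, demo_key,
)
demo_image = jr.normal(demo_key, (channels, height, width))
print(demo_vit(demo_image, enable_dropout=False, key=demo_key).shape)  # (10,)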
@eqx.filter_value_and_grad
def compute_grads(
    model: VisionTransformer, images: jnp.ndarray, labels: jnp.ndarray, key
):
    logits = jax.vmap(model, in_axes=(0, None, 0))(images, True, key)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)

    return jnp.mean(loss)


@eqx.filter_jit
def step_model(
    model: VisionTransformer,
    optimizer: optax.GradientTransformation,
    state: optax.OptState,
    images: jnp.ndarray,
    labels: jnp.ndarray,
    key,
):
    loss, grads = compute_grads(model, images, labels, key)
    updates, new_state = optimizer.update(grads, state, model)

    model = eqx.apply_updates(model, updates)

    return model, new_state, loss


def train(
    model: VisionTransformer,
    optimizer: optax.GradientTransformation,
    state: optax.OptState,
    data_loader: torch.utils.data.DataLoader,
    num_steps: int,
    print_every: int = 1000,
    key=None,
):
    losses = []

    def infinite_trainloader():
        while True:
            yield from data_loader

    for step, batch in zip(range(num_steps), infinite_trainloader()):
        images, labels = batch

        images = images.numpy()
        labels = labels.numpy()

        key, *subkeys = jr.split(key, num=batch_size + 1)
        subkeys = jnp.array(subkeys)

        (model, state, loss) = step_model(
            model, optimizer, state, images, labels, subkeys
        )

        losses.append(loss)

        if (step % print_every) == 0 or step == num_steps - 1:
            print(f"Step: {step}/{num_steps}, Loss: {loss}.")

    return model, state, losses
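
Note how the dropout randomness is handled: inside train we split the key into one sub-key per image in the batch, and compute_grads vmaps the model over both the images and those keys with in_axes=(0, None, 0), so every example gets its own dropout mask. A minimal sketch of that key bookkeeping in isolation (demo_key is only for illustration):

# Split one key into batch_size per-example keys, exactly as done inside train() above.
demo_key = jr.PRNGKey(0)
demo_key, *example_keys = jr.split(demo_key, num=batch_size + 1)
example_keys = jnp.array(example_keys)
print(len(example_keys))  # 64 -- one dropout key per image in the batch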
key = jr.PRNGKey(2003)

model = VisionTransformer(
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout_rate=dropout_rate,
    patch_size=patch_size,
    num_patches=num_patches,
    num_classes=num_classes,
    key=key,
)

optimizer = optax.adamw(
    learning_rate=lr,
    b1=beta1,
    b2=beta2,
)

state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

model, state, losses = train(model, optimizer, state, trainloader, num_steps, key=key)
Step: 0/100000, Loss: 2.5608019828796387.
Step: 1000/100000, Loss: 1.711548089981079.
Step: 2000/100000, Loss: 1.4029508829116821.
Step: 3000/100000, Loss: 1.405516505241394.
Step: 4000/100000, Loss: 1.1661641597747803.
Step: 5000/100000, Loss: 1.1351711750030518.
Step: 6000/100000, Loss: 1.11599600315094.
Step: 7000/100000, Loss: 0.796968936920166.
Step: 8000/100000, Loss: 0.6870157718658447.
Step: 9000/100000, Loss: 1.0474591255187988.
Step: 10000/100000, Loss: 0.9413787722587585.
Step: 11000/100000, Loss: 0.8514565229415894.
Step: 12000/100000, Loss: 0.6746965646743774.
Step: 13000/100000, Loss: 0.7895829677581787.
Step: 14000/100000, Loss: 0.6844460964202881.
Step: 15000/100000, Loss: 0.6571178436279297.
Step: 16000/100000, Loss: 0.5611618757247925.
Step: 17000/100000, Loss: 0.610838770866394.
Step: 18000/100000, Loss: 0.7180566787719727.
Step: 19000/100000, Loss: 0.6528561115264893.
Step: 20000/100000, Loss: 0.5517654418945312.
Step: 21000/100000, Loss: 0.6301887035369873.
Step: 22000/100000, Loss: 0.5667067766189575.
Step: 23000/100000, Loss: 0.43517759442329407.
Step: 24000/100000, Loss: 0.5348870754241943.
Step: 25000/100000, Loss: 0.44732385873794556.
Step: 26000/100000, Loss: 0.49118855595588684.
Step: 27000/100000, Loss: 0.5242345929145813.
Step: 28000/100000, Loss: 0.44588926434516907.
Step: 29000/100000, Loss: 0.23619337379932404.
Step: 30000/100000, Loss: 0.4560542702674866.
Step: 31000/100000, Loss: 0.3148268163204193.
Step: 32000/100000, Loss: 0.4813237488269806.
Step: 33000/100000, Loss: 0.40532559156417847.
Step: 34000/100000, Loss: 0.2517223358154297.
Step: 35000/100000, Loss: 0.322698712348938.
Step: 36000/100000, Loss: 0.3052283525466919.
Step: 37000/100000, Loss: 0.37322986125946045.
Step: 38000/100000, Loss: 0.27499520778656006.
Step: 39000/100000, Loss: 0.2547920346260071.
Step: 40000/100000, Loss: 0.27322614192962646.
Step: 41000/100000, Loss: 0.6049947738647461.
Step: 42000/100000, Loss: 0.28800976276397705.
Step: 43000/100000, Loss: 0.2901820242404938.
Step: 44000/100000, Loss: 0.3800655007362366.
Step: 45000/100000, Loss: 0.15261484682559967.
Step: 46000/100000, Loss: 0.17970965802669525.
Step: 47000/100000, Loss: 0.23651015758514404.
Step: 48000/100000, Loss: 0.3813527822494507.
Step: 49000/100000, Loss: 0.35252541303634644.
Step: 50000/100000, Loss: 0.16249465942382812.
Step: 51000/100000, Loss: 0.10218428075313568.
Step: 52000/100000, Loss: 0.2192973792552948.
Step: 53000/100000, Loss: 0.1880446970462799.
Step: 54000/100000, Loss: 0.14270251989364624.
Step: 55000/100000, Loss: 0.1278090476989746.
Step: 56000/100000, Loss: 0.0856819674372673.
Step: 57000/100000, Loss: 0.16201086342334747.
Step: 58000/100000, Loss: 0.20575015246868134.
Step: 59000/100000, Loss: 0.20935538411140442.
Step: 60000/100000, Loss: 0.09025183320045471.
Step: 61000/100000, Loss: 0.21367806196212769.
Step: 62000/100000, Loss: 0.06895419955253601.
Step: 63000/100000, Loss: 0.14567255973815918.
Step: 64000/100000, Loss: 0.18438486754894257.
Step: 65000/100000, Loss: 0.11639232933521271.
Step: 66000/100000, Loss: 0.06631053984165192.
Step: 67000/100000, Loss: 0.11763929575681686.
Step: 68000/100000, Loss: 0.046494871377944946.
Step: 69000/100000, Loss: 0.14044761657714844.
Step: 70000/100000, Loss: 0.1277393102645874.
Step: 71000/100000, Loss: 0.154437854886055.
Step: 72000/100000, Loss: 0.15087449550628662.
Step: 73000/100000, Loss: 0.05043340474367142.
Step: 74000/100000, Loss: 0.3183276355266571.
Step: 75000/100000, Loss: 0.15685151517391205.
Step: 76000/100000, Loss: 0.13796621561050415.
Step: 77000/100000, Loss: 0.1036764532327652.
Step: 78000/100000, Loss: 0.08222786337137222.
Step: 79000/100000, Loss: 0.1525675356388092.
Step: 80000/100000, Loss: 0.06328584253787994.
Step: 81000/100000, Loss: 0.1235610619187355.
Step: 82000/100000, Loss: 0.03093503788113594.
Step: 83000/100000, Loss: 0.07480041682720184.
Step: 84000/100000, Loss: 0.016707731410861015.
Step: 85000/100000, Loss: 0.0491723008453846.
Step: 86000/100000, Loss: 0.0650872215628624.
Step: 87000/100000, Loss: 0.08738622069358826.
Step: 88000/100000, Loss: 0.10671466588973999.
Step: 89000/100000, Loss: 0.11922930181026459.
Step: 90000/100000, Loss: 0.1234014481306076.
Step: 91000/100000, Loss: 0.08588997274637222.
Step: 92000/100000, Loss: 0.036773063242435455.
Step: 93000/100000, Loss: 0.03425668179988861.
Step: 94000/100000, Loss: 0.21202465891838074.
Step: 95000/100000, Loss: 0.26020047068595886.
Step: 96000/100000, Loss: 0.154791459441185.
Step: 97000/100000, Loss: 0.1340092271566391.
Step: 98000/100000, Loss: 0.11398129910230637.
Step: 99000/100000, Loss: 0.16246598958969116.
Step: 99999/100000, Loss: 0.04668630287051201.

And now let's see how the vision transformer performs on the CIFAR10 test set.

accuracies = []

for images, labels in testloader:

    logits = jax.vmap(functools.partial(model, enable_dropout=False))(
        images.numpy(), key=jax.random.split(key, num=batch_size)
    )

    predictions = jnp.argmax(logits, axis=-1)

    accuracy = jnp.mean(predictions == labels.numpy())

    accuracies.append(accuracy)

print(f"Accuracy: {np.sum(accuracies) / len(accuracies) * 100}%")
Accuracy: 79.13661858974359%

Of course, this is not the best accuracy you can get on CIFAR10, but with longer training and some hyperparameter tuning the vision transformer can do noticeably better.