Vision Transformer (ViT)
This example builds a vision transformer (ViT) model using Equinox, following the paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
In addition to this tutorial example, you may also like the ViT implementation available in Eqxvision here.
Warning
This example will take a short while to run on a GPU.
Reference
@inproceedings{dosovitskiy2021an,
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
author={Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn
and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer
and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
booktitle={International Conference on Learning Representations},
year={2021},
}
import functools
import einops # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax # https://github.com/deepmind/optax
# We'll use PyTorch to load the dataset.
import torch
import torchvision
import torchvision.transforms as transforms
from jaxtyping import Array, Float, PRNGKeyArray
# Hyperparameters
lr = 0.0001
dropout_rate = 0.1
beta1 = 0.9
beta2 = 0.999
batch_size = 64
patch_size = 4
num_patches = 64
num_steps = 100000
image_size = (32, 32, 3)
embedding_dim = 512
hidden_dim = 256
num_heads = 8
num_layers = 6
height, width, channels = image_size
num_classes = 10
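As a quick sanity check (not part of the original example), the hyperparameters above are consistent with each other: a 32x32 image split into 4x4 patches gives an 8x8 grid, i.e. 64 patches.
# Illustrative check only: the patch grid must tile the image exactly.
assert height % patch_size == 0 and width % patch_size == 0
assert num_patches == (height // patch_size) * (width // patch_size)  # 8 * 8 = 64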
Let's first load the CIFAR10 dataset using torchvision.
transform_train = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.Resize((height, width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
transform_test = transforms.Compose(
[
transforms.Resize((height, width)),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
train_dataset = torchvision.datasets.CIFAR10(
"CIFAR",
train=True,
download=True,
transform=transform_train,
)
test_dataset = torchvision.datasets.CIFAR10(
"CIFAR",
train=False,
download=True,
transform=transform_test,
)
trainloader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)
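Before building the model, it can be helpful to peek at a single batch and confirm the shapes. This small check is illustrative only and assumes the loaders defined above.
# Optional sanity check: CIFAR10 batches are channels-first, shape (64, 3, 32, 32).
sample_images, sample_labels = next(iter(trainloader))
print(sample_images.shape, sample_labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])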
Let's start by making the patch embedding layer, which turns an image into a sequence of embedded patches that the attention layers then process.
class PatchEmbedding(eqx.Module):
    linear: eqx.nn.Linear
patch_size: int
def __init__(
self,
input_channels: int,
output_shape: int,
patch_size: int,
key: PRNGKeyArray,
):
self.patch_size = patch_size
self.linear = eqx.nn.Linear(
self.patch_size**2 * input_channels,
output_shape,
key=key,
)
def __call__(
self, x: Float[Array, "channels height width"]
) -> Float[Array, "num_patches embedding_dim"]:
x = einops.rearrange(
x,
"c (h ph) (w pw) -> (h w) (c ph pw)",
ph=self.patch_size,
pw=self.patch_size,
)
x = jax.vmap(self.linear)(x)
return x
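To see what this layer does, here is a minimal sketch (with a throwaway key and dummy data, not used later) showing a single CIFAR10-sized image being turned into a sequence of patch embeddings.
# Illustrative only: one 3x32x32 image becomes 64 patch embeddings of size 512.
demo_key = jr.PRNGKey(0)
demo_embedder = PatchEmbedding(channels, embedding_dim, patch_size, demo_key)
demo_image = jr.normal(demo_key, (channels, height, width))
print(demo_embedder(demo_image).shape)  # (64, 512)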
After that, we implement the attention block, which is the core of the transformer architecture.
class AttentionBlock(eqx.Module):
layer_norm1: eqx.nn.LayerNorm
layer_norm2: eqx.nn.LayerNorm
attention: eqx.nn.MultiheadAttention
linear1: eqx.nn.Linear
linear2: eqx.nn.Linear
dropout1: eqx.nn.Dropout
dropout2: eqx.nn.Dropout
def __init__(
self,
input_shape: int,
hidden_dim: int,
num_heads: int,
dropout_rate: float,
key: PRNGKeyArray,
):
key1, key2, key3 = jr.split(key, 3)
self.layer_norm1 = eqx.nn.LayerNorm(input_shape)
self.layer_norm2 = eqx.nn.LayerNorm(input_shape)
self.attention = eqx.nn.MultiheadAttention(num_heads, input_shape, key=key1)
self.linear1 = eqx.nn.Linear(input_shape, hidden_dim, key=key2)
self.linear2 = eqx.nn.Linear(hidden_dim, input_shape, key=key3)
self.dropout1 = eqx.nn.Dropout(dropout_rate)
self.dropout2 = eqx.nn.Dropout(dropout_rate)
def __call__(
self,
x: Float[Array, "num_patches embedding_dim"],
enable_dropout: bool,
key: PRNGKeyArray,
) -> Float[Array, "num_patches embedding_dim"]:
input_x = jax.vmap(self.layer_norm1)(x)
x = x + self.attention(input_x, input_x, input_x)
input_x = jax.vmap(self.layer_norm2)(x)
input_x = jax.vmap(self.linear1)(input_x)
input_x = jax.nn.gelu(input_x)
key1, key2 = jr.split(key, num=2)
input_x = self.dropout1(input_x, inference=not enable_dropout, key=key1)
input_x = jax.vmap(self.linear2)(input_x)
input_x = self.dropout2(input_x, inference=not enable_dropout, key=key2)
x = x + input_x
return x
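Since each block maps a (num_patches, embedding_dim) sequence to a sequence of the same shape, blocks can be stacked freely. A small shape check (illustrative only, with dummy inputs) makes this concrete.
# Illustrative only: an attention block preserves the sequence shape.
demo_block_key, demo_call_key = jr.split(jr.PRNGKey(0))
demo_block = AttentionBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, demo_block_key)
demo_tokens = jr.normal(demo_call_key, (num_patches, embedding_dim))
print(demo_block(demo_tokens, enable_dropout=False, key=demo_call_key).shape)  # (64, 512)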
Lastly, we build the full vision transformer model, which is composed of the embedding layers (patch, positional, and class-token embeddings), a stack of attention blocks, and a classification head.
class VisionTransformer(eqx.Module):
patch_embedding: PatchEmbedding
positional_embedding: jnp.ndarray
cls_token: jnp.ndarray
attention_blocks: list[AttentionBlock]
dropout: eqx.nn.Dropout
mlp: eqx.nn.Sequential
num_layers: int
def __init__(
self,
embedding_dim: int,
hidden_dim: int,
num_heads: int,
num_layers: int,
dropout_rate: float,
patch_size: int,
num_patches: int,
num_classes: int,
key: PRNGKeyArray,
):
key1, key2, key3, key4, key5 = jr.split(key, 5)
self.patch_embedding = PatchEmbedding(channels, embedding_dim, patch_size, key1)
self.positional_embedding = jr.normal(key2, (num_patches + 1, embedding_dim))
self.cls_token = jr.normal(key3, (1, embedding_dim))
self.num_layers = num_layers
        self.attention_blocks = [
            AttentionBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, block_key)
            # Give each block its own key so the blocks are not initialised identically.
            for block_key in jr.split(key4, self.num_layers)
        ]
self.dropout = eqx.nn.Dropout(dropout_rate)
self.mlp = eqx.nn.Sequential(
[
eqx.nn.LayerNorm(embedding_dim),
eqx.nn.Linear(embedding_dim, num_classes, key=key5),
]
)
def __call__(
self,
x: Float[Array, "channels height width"],
enable_dropout: bool,
key: PRNGKeyArray,
) -> Float[Array, "num_classes"]:
x = self.patch_embedding(x)
x = jnp.concatenate((self.cls_token, x), axis=0)
x += self.positional_embedding[
: x.shape[0]
] # Slice to the same length as x, as the positional embedding may be longer.
dropout_key, *attention_keys = jr.split(key, num=self.num_layers + 1)
x = self.dropout(x, inference=not enable_dropout, key=dropout_key)
for block, attention_key in zip(self.attention_blocks, attention_keys):
x = block(x, enable_dropout, key=attention_key)
x = x[0] # Select the CLS token.
x = self.mlp(x)
return x
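For a single image, the model returns one logit per class; batching over images is handled later with jax.vmap. The following is just a shape sketch with a dummy model and dummy data, separate from the model trained below.
# Illustrative only: a single image maps to a vector of 10 class logits.
demo_model_key, demo_image_key = jr.split(jr.PRNGKey(0))
demo_model = VisionTransformer(
    embedding_dim, hidden_dim, num_heads, num_layers, dropout_rate,
    patch_size, num_patches, num_classes, key=demo_model_key,
)
demo_image = jr.normal(demo_image_key, (channels, height, width))
print(demo_model(demo_image, enable_dropout=False, key=demo_image_key).shape)  # (10,)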
@eqx.filter_value_and_grad
def compute_grads(
model: VisionTransformer, images: jnp.ndarray, labels: jnp.ndarray, key
):
logits = jax.vmap(model, in_axes=(0, None, 0))(images, True, key)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
return jnp.mean(loss)
@eqx.filter_jit
def step_model(
model: VisionTransformer,
optimizer: optax.GradientTransformation,
state: optax.OptState,
images: jnp.ndarray,
labels: jnp.ndarray,
key,
):
loss, grads = compute_grads(model, images, labels, key)
updates, new_state = optimizer.update(grads, state, model)
model = eqx.apply_updates(model, updates)
return model, new_state, loss
def train(
model: VisionTransformer,
optimizer: optax.GradientTransformation,
state: optax.OptState,
data_loader: torch.utils.data.DataLoader,
num_steps: int,
print_every: int = 1000,
key=None,
):
losses = []
def infinite_trainloader():
while True:
yield from data_loader
for step, batch in zip(range(num_steps), infinite_trainloader()):
images, labels = batch
images = images.numpy()
labels = labels.numpy()
key, *subkeys = jr.split(key, num=batch_size + 1)
subkeys = jnp.array(subkeys)
(model, state, loss) = step_model(
model, optimizer, state, images, labels, subkeys
)
losses.append(loss)
if (step % print_every) == 0 or step == num_steps - 1:
print(f"Step: {step}/{num_steps}, Loss: {loss}.")
return model, state, losses
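Note how compute_grads vmaps the model over both the images and the keys (in_axes=(0, None, 0)), which is why train splits one key into batch_size per-example dropout keys each step. A tiny illustration of the shapes involved, assuming the legacy uint32 key format returned by jr.PRNGKey:
# Illustrative only: one dropout key per batch element, stacked into an array.
demo_key, *demo_subkeys = jr.split(jr.PRNGKey(0), num=batch_size + 1)
print(jnp.array(demo_subkeys).shape)  # (64, 2) with legacy uint32 keys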
key = jr.PRNGKey(2003)
model = VisionTransformer(
embedding_dim=embedding_dim,
hidden_dim=hidden_dim,
num_heads=num_heads,
num_layers=num_layers,
dropout_rate=dropout_rate,
patch_size=patch_size,
num_patches=num_patches,
num_classes=num_classes,
key=key,
)
optimizer = optax.adamw(
learning_rate=lr,
b1=beta1,
b2=beta2,
)
state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
model, state, losses = train(model, optimizer, state, trainloader, num_steps, key=key)
And now let's see how the vision transformer performs on the CIFAR10 dataset.
accuracies = []
# Iterate over the test set exactly once; drop_last=True means every batch has
# batch_size elements, matching the number of keys we split for vmap.
for images, labels in testloader:
    logits = jax.vmap(functools.partial(model, enable_dropout=False))(
        images.numpy(), key=jax.random.split(key, num=batch_size)
    )
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == labels.numpy())
    accuracies.append(accuracy)
print(f"Accuracy: {np.sum(accuracies) / len(accuracies) * 100}%")
Of course, this is not the best accuracy you can get on CIFAR10, but with longer training and some hyperparameter tuning, the vision transformer can do considerably better.