# U-Net implementation

This is an advanced example, providing an implementation of a U-Net architecture.

This version is intended for use in score-based diffusion models, so in addition to the input image its call signature accepts a time `t`.

This example is available as a Jupyter notebook in the Equinox repository.

Author: Ben Walker (https://github.com/Benjamin-Walker)

```python
import math
from collections.abc import Callable
from typing import Optional, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
from einops import rearrange
```

```python
class SinusoidalPosEmb(eqx.Module):
    emb: jax.Array

    def __init__(self, dim):
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        self.emb = jnp.exp(jnp.arange(half_dim) * -emb)

    def __call__(self, x):
        emb = x * self.emb
        emb = jnp.concatenate((jnp.sin(emb), jnp.cos(emb)), axis=-1)
        return emb
```
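
As a quick illustration (not part of the original notebook): the embedding maps a scalar time to a vector of length `dim`, with sines in the first half and cosines in the second.

```python
# Hypothetical usage sketch for SinusoidalPosEmb.
pos_emb = SinusoidalPosEmb(64)
print(pos_emb(jnp.array(0.5)).shape)  # (64,)
```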

```python
class LinearTimeSelfAttention(eqx.Module):
    group_norm: eqx.nn.GroupNorm
    heads: int
    to_qkv: eqx.nn.Conv2d
    to_out: eqx.nn.Conv2d

    def __init__(
        self,
        dim,
        key,
        heads=4,
        dim_head=32,
    ):
        keys = jax.random.split(key, 2)
        self.group_norm = eqx.nn.GroupNorm(min(dim // 4, 32), dim)
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = eqx.nn.Conv2d(dim, hidden_dim * 3, 1, key=keys[0])
        self.to_out = eqx.nn.Conv2d(hidden_dim, dim, 1, key=keys[1])

    def __call__(self, x):
        c, h, w = x.shape
        x = self.group_norm(x)
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "(qkv heads c) h w -> qkv heads c (h w)", heads=self.heads, qkv=3
        )
        k = jax.nn.softmax(k, axis=-1)
        # "Linear" attention: keys/values are summarised into a
        # (dim_head, dim_head) context per head, so the cost is linear in the
        # number of pixels rather than quadratic.
        context = jnp.einsum("hdn,hen->hde", k, v)
        out = jnp.einsum("hde,hdn->hen", context, q)
        out = rearrange(
            out, "heads c (h w) -> (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)
```
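
A minimal shape check (an illustrative addition, not from the original): the attention block maps a `(channels, height, width)` array to one of the same shape.

```python
# Hypothetical smoke test for LinearTimeSelfAttention.
attn = LinearTimeSelfAttention(dim=32, key=jr.PRNGKey(0))
x = jnp.zeros((32, 8, 8))
print(attn(x).shape)  # (32, 8, 8)
```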

```python
def upsample_2d(y, factor=2):
    C, H, W = y.shape
    y = jnp.reshape(y, [C, H, 1, W, 1])
    y = jnp.tile(y, [1, 1, factor, 1, factor])
    return jnp.reshape(y, [C, H * factor, W * factor])


def downsample_2d(y, factor=2):
    C, H, W = y.shape
    y = jnp.reshape(y, [C, H // factor, factor, W // factor, factor])
    return jnp.mean(y, axis=[2, 4])


def exact_zip(*args):
    _len = len(args[0])
    for arg in args:
        assert len(arg) == _len
    return zip(*args)


def key_split_allowing_none(key):
    if key is None:
        return key, None
    else:
        return jr.split(key)
```
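
For intuition (an illustrative addition): `upsample_2d` repeats each pixel into a `factor x factor` block, and `downsample_2d` averages each `factor x factor` patch.

```python
y = jnp.arange(16.0).reshape(1, 4, 4)
print(upsample_2d(y).shape)    # (1, 8, 8) -- nearest-neighbour upsampling
print(downsample_2d(y).shape)  # (1, 2, 2) -- mean pooling
```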

```python
class Residual(eqx.Module):
    fn: LinearTimeSelfAttention

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
```

```python
class ResnetBlock(eqx.Module):
    dim_out: int
    is_biggan: bool
    up: bool
    down: bool
    dropout_rate: float
    time_emb_dim: int
    mlp_layers: list[Union[Callable, eqx.nn.Linear]]
    scaling: Union[None, Callable, eqx.nn.ConvTranspose2d, eqx.nn.Conv2d]
    block1_groupnorm: eqx.nn.GroupNorm
    block1_conv: eqx.nn.Conv2d
    block2_layers: list[
        Union[eqx.nn.GroupNorm, eqx.nn.Dropout, eqx.nn.Conv2d, Callable]
    ]
    res_conv: eqx.nn.Conv2d
    attn: Optional[Residual]

    def __init__(
        self,
        dim_in,
        dim_out,
        is_biggan,
        up,
        down,
        time_emb_dim,
        dropout_rate,
        is_attn,
        *,
        key,
    ):
        keys = jax.random.split(key, 7)
        self.dim_out = dim_out
        self.is_biggan = is_biggan
        self.up = up
        self.down = down
        self.dropout_rate = dropout_rate
        self.time_emb_dim = time_emb_dim

        self.mlp_layers = [
            jax.nn.silu,
            eqx.nn.Linear(time_emb_dim, dim_out, key=keys[0]),
        ]
        self.block1_groupnorm = eqx.nn.GroupNorm(min(dim_in // 4, 32), dim_in)
        self.block1_conv = eqx.nn.Conv2d(dim_in, dim_out, 3, padding=1, key=keys[1])
        self.block2_layers = [
            eqx.nn.GroupNorm(min(dim_out // 4, 32), dim_out),
            jax.nn.silu,
            eqx.nn.Dropout(dropout_rate),
            eqx.nn.Conv2d(dim_out, dim_out, 3, padding=1, key=keys[2]),
        ]

        assert not self.up or not self.down

        if is_biggan:
            if self.up:
                self.scaling = upsample_2d
            elif self.down:
                self.scaling = downsample_2d
            else:
                self.scaling = None
        else:
            if self.up:
                self.scaling = eqx.nn.ConvTranspose2d(
                    dim_in,
                    dim_in,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    key=keys[3],
                )
            elif self.down:
                self.scaling = eqx.nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    key=keys[4],
                )
            else:
                self.scaling = None
        # For DDPM, Yang Song et al. use their own custom layer called NIN,
        # which is equivalent to a 1x1 conv.
        self.res_conv = eqx.nn.Conv2d(dim_in, dim_out, kernel_size=1, key=keys[5])

        if is_attn:
            self.attn = Residual(
                LinearTimeSelfAttention(
                    dim_out,
                    key=keys[6],
                )
            )
        else:
            self.attn = None

    def __call__(self, x, t, *, key):
        C, _, _ = x.shape
        # In DDPM, each set of resblocks ends with an up/down sampling. In
        # biggan there is a final resblock after the up/downsampling. In this
        # code, the biggan approach is taken for both.
        # norm -> nonlinearity -> up/downsample -> conv follows Yang Song's
        # implementation:
        # https://github.dev/yang-song/score_sde/blob/main/models/layerspp.py
        h = jax.nn.silu(self.block1_groupnorm(x))
        if self.up or self.down:
            h = self.scaling(h)  # pyright: ignore
            x = self.scaling(x)  # pyright: ignore
        h = self.block1_conv(h)

        for layer in self.mlp_layers:
            t = layer(t)
        h = h + t[..., None, None]
        for layer in self.block2_layers:
            # Precisely one layer in block2_layers (the dropout) requires a key.
            if isinstance(layer, eqx.nn.Dropout):
                h = layer(h, key=key)
            else:
                h = layer(h)

        if C != self.dim_out or self.up or self.down:
            x = self.res_conv(x)

        out = (h + x) / jnp.sqrt(2)
        if self.attn is not None:
            out = self.attn(out)
        return out
```
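
As a sanity check (again an illustrative addition, with arbitrary sizes): a downsampling block halves the spatial resolution while changing the channel count.

```python
# Hypothetical smoke test for a downsampling ResnetBlock.
block = ResnetBlock(
    dim_in=32, dim_out=64, is_biggan=False, up=False, down=True,
    time_emb_dim=16, dropout_rate=0.1, is_attn=False, key=jr.PRNGKey(0),
)
x = jnp.zeros((32, 16, 16))
t_emb = jnp.zeros((16,))  # time embedding of size time_emb_dim
print(block(x, t_emb, key=jr.PRNGKey(1)).shape)  # (64, 8, 8)
```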

```python
class UNet(eqx.Module):
    time_pos_emb: SinusoidalPosEmb
    mlp: eqx.nn.MLP
    first_conv: eqx.nn.Conv2d
    down_res_blocks: list[list[ResnetBlock]]
    mid_block1: ResnetBlock
    mid_block2: ResnetBlock
    ups_res_blocks: list[list[ResnetBlock]]
    final_conv_layers: list[Union[Callable, eqx.nn.GroupNorm, eqx.nn.Conv2d]]

    def __init__(
        self,
        data_shape: tuple[int, int, int],
        is_biggan: bool,
        dim_mults: list[int],
        hidden_size: int,
        dropout_rate: float,
        num_res_blocks: int,
        attn_resolutions: list[int],
        *,
        key,
    ):
        keys = jax.random.split(key, 7)
        del key

        data_channels, in_height, in_width = data_shape

        dims = [hidden_size] + [hidden_size * m for m in dim_mults]
        in_out = list(exact_zip(dims[:-1], dims[1:]))

        self.time_pos_emb = SinusoidalPosEmb(hidden_size)
        self.mlp = eqx.nn.MLP(
            hidden_size,
            hidden_size,
            4 * hidden_size,
            1,
            activation=jax.nn.silu,
            key=keys[0],
        )
        self.first_conv = eqx.nn.Conv2d(
            data_channels, hidden_size, kernel_size=3, padding=1, key=keys[1]
        )

        h, w = in_height, in_width
        self.down_res_blocks = []
        num_keys = len(in_out) * num_res_blocks - 1
        keys_resblock = jr.split(keys[2], num_keys)
        i = 0
        for ind, (dim_in, dim_out) in enumerate(in_out):
            if h in attn_resolutions and w in attn_resolutions:
                is_attn = True
            else:
                is_attn = False
            res_blocks = [
                ResnetBlock(
                    dim_in=dim_in,
                    dim_out=dim_out,
                    is_biggan=is_biggan,
                    up=False,
                    down=False,
                    time_emb_dim=hidden_size,
                    dropout_rate=dropout_rate,
                    is_attn=is_attn,
                    key=keys_resblock[i],
                )
            ]
            i += 1
            for _ in range(num_res_blocks - 2):
                res_blocks.append(
                    ResnetBlock(
                        dim_in=dim_out,
                        dim_out=dim_out,
                        is_biggan=is_biggan,
                        up=False,
                        down=False,
                        time_emb_dim=hidden_size,
                        dropout_rate=dropout_rate,
                        is_attn=is_attn,
                        key=keys_resblock[i],
                    )
                )
                i += 1
            if ind < (len(in_out) - 1):
                res_blocks.append(
                    ResnetBlock(
                        dim_in=dim_out,
                        dim_out=dim_out,
                        is_biggan=is_biggan,
                        up=False,
                        down=True,
                        time_emb_dim=hidden_size,
                        dropout_rate=dropout_rate,
                        is_attn=is_attn,
                        key=keys_resblock[i],
                    )
                )
                i += 1
                h, w = h // 2, w // 2
            self.down_res_blocks.append(res_blocks)
        assert i == num_keys

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(
            dim_in=mid_dim,
            dim_out=mid_dim,
            is_biggan=is_biggan,
            up=False,
            down=False,
            time_emb_dim=hidden_size,
            dropout_rate=dropout_rate,
            is_attn=True,
            key=keys[3],
        )
        self.mid_block2 = ResnetBlock(
            dim_in=mid_dim,
            dim_out=mid_dim,
            is_biggan=is_biggan,
            up=False,
            down=False,
            time_emb_dim=hidden_size,
            dropout_rate=dropout_rate,
            is_attn=False,
            key=keys[4],
        )

        self.ups_res_blocks = []
        num_keys = len(in_out) * (num_res_blocks + 1) - 1
        keys_resblock = jr.split(keys[5], num_keys)
        i = 0
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            if h in attn_resolutions and w in attn_resolutions:
                is_attn = True
            else:
                is_attn = False
            res_blocks = []
            for _ in range(num_res_blocks - 1):
                res_blocks.append(
                    ResnetBlock(
                        dim_in=dim_out * 2,
                        dim_out=dim_out,
                        is_biggan=is_biggan,
                        up=False,
                        down=False,
                        time_emb_dim=hidden_size,
                        dropout_rate=dropout_rate,
                        is_attn=is_attn,
                        key=keys_resblock[i],
                    )
                )
                i += 1
            res_blocks.append(
                ResnetBlock(
                    dim_in=dim_out + dim_in,
                    dim_out=dim_in,
                    is_biggan=is_biggan,
                    up=False,
                    down=False,
                    time_emb_dim=hidden_size,
                    dropout_rate=dropout_rate,
                    is_attn=is_attn,
                    key=keys_resblock[i],
                )
            )
            i += 1
            if ind < (len(in_out) - 1):
                res_blocks.append(
                    ResnetBlock(
                        dim_in=dim_in,
                        dim_out=dim_in,
                        is_biggan=is_biggan,
                        up=True,
                        down=False,
                        time_emb_dim=hidden_size,
                        dropout_rate=dropout_rate,
                        is_attn=is_attn,
                        key=keys_resblock[i],
                    )
                )
                i += 1
                h, w = h * 2, w * 2

            self.ups_res_blocks.append(res_blocks)
        assert i == num_keys

        self.final_conv_layers = [
            eqx.nn.GroupNorm(min(hidden_size // 4, 32), hidden_size),
            jax.nn.silu,
            eqx.nn.Conv2d(hidden_size, data_channels, 1, key=keys[6]),
        ]

    def __call__(self, t, y, *, key=None):
        t = self.time_pos_emb(t)
        t = self.mlp(t)
        h = self.first_conv(y)
        hs = [h]
        for res_blocks in self.down_res_blocks:
            for res_block in res_blocks:
                key, subkey = key_split_allowing_none(key)
                h = res_block(h, t, key=subkey)
                hs.append(h)

        key, subkey = key_split_allowing_none(key)
        h = self.mid_block1(h, t, key=subkey)
        key, subkey = key_split_allowing_none(key)
        h = self.mid_block2(h, t, key=subkey)

        for res_blocks in self.ups_res_blocks:
            for res_block in res_blocks:
                key, subkey = key_split_allowing_none(key)
                if res_block.up:
                    h = res_block(h, t, key=subkey)
                else:
                    h = res_block(jnp.concatenate((h, hs.pop()), axis=0), t, key=subkey)

        assert len(hs) == 0

        for layer in self.final_conv_layers:
            h = layer(h)
        return h
```
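
Finally, a small usage sketch (an illustrative addition, not part of the original notebook). The hyperparameters here are arbitrary, chosen so a single-channel 28x28 input runs quickly; `attn_resolutions` lists the spatial resolutions at which attention is enabled.

```python
# Hypothetical example: instantiate the UNet and run one forward pass.
model_key, call_key = jr.split(jr.PRNGKey(0))
model = UNet(
    data_shape=(1, 28, 28),  # (channels, height, width)
    is_biggan=False,
    dim_mults=[2, 4],
    hidden_size=16,
    dropout_rate=0.1,
    num_res_blocks=2,
    attn_resolutions=[14],
    key=model_key,
)
t = jnp.array(0.5)          # diffusion time
y = jnp.zeros((1, 28, 28))  # a single (unbatched) image; use jax.vmap for batches
out = model(t, y, key=call_key)
print(out.shape)  # (1, 28, 28) -- same shape as the input
```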