Continuous Normalising Flow
This example is a bit of fun! It constructs a continuous normalising flow (CNF) to learn a distribution specified by a (greyscale) image. That is, the target distribution is over \(\mathbb{R}^2\), and the image specifies the (unnormalised) density at each point.
You can specify your own images, and learn your own flows.
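(To make that concrete, looking ahead to the get_data function below: darker pixels are given larger weights, so up to normalisation the target density is roughly
\[
p_{\text{target}}(x, y) \;\propto\; 1 - \frac{\text{greyscale}(x, y)}{\max_{x', y'} \text{greyscale}(x', y')},
\]
with (near-)white pixels receiving essentially zero mass.)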
Some example outputs from this script:
Reference:
@article{grathwohl2019ffjord,
    title={{FFJORD}: {F}ree-form {C}ontinuous {D}ynamics for {S}calable {R}eversible {G}enerative {M}odels},
    author={Grathwohl, Will and Chen, Ricky T. Q. and Bettencourt, Jesse and Sutskever, Ilya and Duvenaud, David},
    journal={International Conference on Learning Representations},
    year={2019},
}
This example is available as a Jupyter notebook here.
Warning
This example will need a GPU to run efficiently.
Advanced example
This is a pretty advanced example.
import math
import os
import pathlib
import time
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import imageio
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax # https://github.com/deepmind/optax
import scipy.stats as stats
here = pathlib.Path(os.getcwd())
First let's define some vector fields. Each one is basically just an MLP, using tanh as the activation function and "ConcatSquash" layers instead of linear layers.
This is the vector field on the right hand side of the ODE.
class Func(eqx.Module):
layers: list[eqx.nn.Linear]
def __init__(self, *, data_size, width_size, depth, key, **kwargs):
super().__init__(**kwargs)
keys = jr.split(key, depth + 1)
layers = []
if depth == 0:
layers.append(
ConcatSquash(in_size=data_size, out_size=data_size, key=keys[0])
)
else:
layers.append(
ConcatSquash(in_size=data_size, out_size=width_size, key=keys[0])
)
for i in range(depth - 1):
layers.append(
ConcatSquash(
in_size=width_size, out_size=width_size, key=keys[i + 1]
)
)
layers.append(
ConcatSquash(in_size=width_size, out_size=data_size, key=keys[-1])
)
self.layers = layers
def __call__(self, t, y, args):
t = jnp.asarray(t)[None]
for layer in self.layers[:-1]:
y = layer(t, y)
y = jnn.tanh(y)
y = self.layers[-1](t, y)
return y
# Credit: this layer, and some of the default hyperparameters below, are taken from the
# FFJORD repo.
class ConcatSquash(eqx.Module):
lin1: eqx.nn.Linear
lin2: eqx.nn.Linear
lin3: eqx.nn.Linear
def __init__(self, *, in_size, out_size, key, **kwargs):
super().__init__(**kwargs)
key1, key2, key3 = jr.split(key, 3)
self.lin1 = eqx.nn.Linear(in_size, out_size, key=key1)
self.lin2 = eqx.nn.Linear(1, out_size, key=key2)
self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=key3)
def __call__(self, t, y):
return self.lin1(y) * jnn.sigmoid(self.lin2(t)) + self.lin3(t)
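As a quick shape check (this snippet is not part of the original script; the sizes and variable names are purely illustrative), a Func built with data_size=2 maps a scalar time and a 2-dimensional state to a 2-dimensional tangent vector:
# Hypothetical shape check, not part of the example proper.
_demo_func = Func(data_size=2, width_size=8, depth=2, key=jr.PRNGKey(0))
assert _demo_func(0.3, jnp.array([0.1, -0.2]), None).shape == (2,)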
When training, we need to wrap our vector fields in something that also computes the change in log-density.
This can be done either approximately (using Hutchinson's trace estimator) or exactly (by computing the divergence of the vector field, which is relatively computationally expensive).
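For reference, the quantity these wrappers compute comes from the instantaneous change of variables formula of the FFJORD paper cited above: for the flow \(\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y)\),
\[
\frac{\mathrm{d}\log p(y(t))}{\mathrm{d}t} = -\operatorname{tr}\!\left(\frac{\partial f}{\partial y}(t, y(t))\right),
\qquad
\operatorname{tr}(A) = \mathbb{E}_{\varepsilon \sim \mathcal{N}(0, I)}\!\left[\varepsilon^{\top} A\, \varepsilon\right].
\]
The exact wrapper below recovers the full Jacobian by vmapping the vjp over the rows of the identity and takes its trace; the approximate wrapper contracts a single probe vector \(\varepsilon\) against the vjp, giving Hutchinson's unbiased estimate of the trace at roughly the cost of one extra backward pass. (The sign works out because the training solve runs backwards in time, from \(t_1\) to \(t_0\).)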
def approx_logp_wrapper(t, y, args):
y, _ = y
*args, eps, func = args
fn = lambda y: func(t, y, args)
f, vjp_fn = jax.vjp(fn, y)
(eps_dfdy,) = vjp_fn(eps)
logp = jnp.sum(eps_dfdy * eps)
return f, logp
def exact_logp_wrapper(t, y, args):
y, _ = y
*args, _, func = args
fn = lambda y: func(t, y, args)
f, vjp_fn = jax.vjp(fn, y)
(size,) = y.shape # this implementation only works for 1D input
eye = jnp.eye(size)
(dfdy,) = jax.vmap(vjp_fn)(eye)
logp = jnp.trace(dfdy)
return f, logp
Wrap up the differential equation solve into a model.
def normal_log_likelihood(y):
return -0.5 * (y.size * math.log(2 * math.pi) + jnp.sum(y**2))
class CNF(eqx.Module):
funcs: list[Func]
data_size: int
exact_logp: bool
t0: float
t1: float
dt0: float
def __init__(
self,
*,
data_size,
exact_logp,
num_blocks,
width_size,
depth,
key,
**kwargs,
):
super().__init__(**kwargs)
keys = jr.split(key, num_blocks)
self.funcs = [
Func(
data_size=data_size,
width_size=width_size,
depth=depth,
key=k,
)
for k in keys
]
self.data_size = data_size
self.exact_logp = exact_logp
self.t0 = 0.0
self.t1 = 0.5
self.dt0 = 0.05
# Runs backward-in-time to train the CNF.
def train(self, y, *, key):
if self.exact_logp:
term = diffrax.ODETerm(exact_logp_wrapper)
else:
term = diffrax.ODETerm(approx_logp_wrapper)
solver = diffrax.Tsit5()
eps = jr.normal(key, y.shape)
delta_log_likelihood = 0.0
for func in reversed(self.funcs):
y = (y, delta_log_likelihood)
sol = diffrax.diffeqsolve(
term, solver, self.t1, self.t0, -self.dt0, y, (eps, func)
)
(y,), (delta_log_likelihood,) = sol.ys
return delta_log_likelihood + normal_log_likelihood(y)
# Runs forward-in-time to draw samples from the CNF.
def sample(self, *, key):
y = jr.normal(key, (self.data_size,))
for func in self.funcs:
term = diffrax.ODETerm(func)
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(term, solver, self.t0, self.t1, self.dt0, y)
(y,) = sol.ys
return y
# To make illustrations, we have a variant sample method we can query to see the
# evolution of the samples during the forward solve.
def sample_flow(self, *, key):
t_so_far = self.t0
t_end = self.t0 + (self.t1 - self.t0) * len(self.funcs)
save_times = jnp.linspace(self.t0, t_end, 6)
y = jr.normal(key, (self.data_size,))
out = []
for i, func in enumerate(self.funcs):
if i == len(self.funcs) - 1:
save_ts = save_times[t_so_far <= save_times] - t_so_far
else:
save_ts = (
save_times[
(t_so_far <= save_times)
& (save_times < t_so_far + self.t1 - self.t0)
]
- t_so_far
)
t_so_far = t_so_far + self.t1 - self.t0
term = diffrax.ODETerm(func)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=save_ts)
sol = diffrax.diffeqsolve(
term, solver, self.t0, self.t1, self.dt0, y, saveat=saveat
)
out.append(sol.ys)
y = sol.ys[-1]
out = jnp.concatenate(out)
assert len(out) == 6 # number of points we saved at
return out
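Just to orient ourselves, here is a minimal usage sketch (not part of the original script; the hyperparameters and names are illustrative). train returns a scalar log-likelihood for a single data point, and sample returns a single point of shape (data_size,); in the training code below both are used under jax.vmap.
# Hypothetical usage sketch; hyperparameters here are arbitrary.
_mkey, _tkey, _skey = jr.split(jr.PRNGKey(0), 3)
_demo_cnf = CNF(
    data_size=2, exact_logp=True, num_blocks=2, width_size=8, depth=2, key=_mkey
)
_log_lik = _demo_cnf.train(jnp.array([0.0, 0.0]), key=_tkey)  # scalar
_point = _demo_cnf.sample(key=_skey)  # shape (data_size,)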
Alright, that's the models done. Now let's get some data.
First we have a function for taking the specified input image and turning it into data.
def get_data(path):
# integer array of shape (height, width, channels) with values in {0, ..., 255}
img = jnp.asarray(imageio.imread(path))
if img.shape[-1] == 4:
img = img[..., :-1] # ignore alpha channel
height, width, channels = img.shape
assert channels == 3
# Convert to greyscale for simplicity.
img = img @ jnp.array([0.2989, 0.5870, 0.1140])
img = jnp.transpose(img)[:, ::-1] # (width, height)
x = jnp.arange(width, dtype=jnp.float32)
y = jnp.arange(height, dtype=jnp.float32)
x, y = jnp.broadcast_arrays(x[:, None], y[None, :])
weights = 1 - img.reshape(-1).astype(jnp.float32) / jnp.max(img)
dataset = jnp.stack(
[x.reshape(-1), y.reshape(-1)], axis=-1
) # shape (dataset_size, 2)
# For efficiency we don't bother with the particles that will have weight zero.
cond = img.reshape(-1) < 254
dataset = dataset[cond]
weights = weights[cond]
mean = jnp.mean(dataset, axis=0)
std = jnp.std(dataset, axis=0) + 1e-6
dataset = (dataset - mean) / std
return dataset, weights, mean, std, img, width, height
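If you want a quick smoke test of get_data without a real image to hand, something like the following should work (this snippet is an illustrative assumption, not part of the original script; it just writes a small random RGB image to a temporary path and loads it back):
# Hypothetical smoke test for get_data, not part of the example proper.
import tempfile
import numpy as np

_tmp_path = pathlib.Path(tempfile.mkdtemp()) / "smoke_test.png"
_rng = np.random.default_rng(0)
imageio.imwrite(_tmp_path, _rng.integers(0, 256, size=(32, 48, 3), dtype=np.uint8))
_ds, _w, _mean, _std, _img, _width, _height = get_data(_tmp_path)
assert _ds.shape[1] == 2 and _w.shape[0] == _ds.shape[0]
assert (_width, _height) == (48, 32)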
Now to load the data during training, we need a dataloader. In this case our dataset is small enough to fit in memory, so we use a dataloader implementation that we can include within our overall JIT wrapper, for speed.
class DataLoader(eqx.Module):
arrays: tuple[jnp.ndarray, ...]
batch_size: int
key: jr.PRNGKey
def __check_init__(self):
dataset_size = self.arrays[0].shape[0]
assert all(array.shape[0] == dataset_size for array in self.arrays)
def __call__(self, step):
dataset_size = self.arrays[0].shape[0]
num_batches = dataset_size // self.batch_size
epoch = step // num_batches
key = jr.fold_in(self.key, epoch)
perm = jr.permutation(key, jnp.arange(dataset_size))
start = (step % num_batches) * self.batch_size
slice_size = self.batch_size
batch_indices = lax.dynamic_slice_in_dim(perm, start, slice_size)
return tuple(array[batch_indices] for array in self.arrays)
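Here is a rough usage sketch (illustrative only, not part of the original script): successive step values walk through the current epoch's permutation batch by batch, and jr.fold_in draws a fresh permutation at the start of each epoch, so everything stays pure and JIT-friendly.
# Hypothetical usage sketch; the training loop below drives this with a step counter.
_demo_data = jnp.arange(10.0)[:, None]
_demo_loader = DataLoader((_demo_data,), batch_size=4, key=jr.PRNGKey(0))
(_first_batch,) = _demo_loader(0)   # first batch of epoch 0
(_second_batch,) = _demo_loader(1)  # second batch of epoch 0
assert _first_batch.shape == (4, 1)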
Bring everything together. This function is our entry point.
def main(
in_path,
out_path=None,
batch_size=500,
virtual_batches=2,
lr=1e-3,
weight_decay=1e-5,
steps=10000,
exact_logp=True,
num_blocks=2,
width_size=64,
depth=3,
print_every=100,
seed=5678,
):
if out_path is None:
out_path = here / pathlib.Path(in_path).name
else:
out_path = pathlib.Path(out_path)
key = jr.PRNGKey(seed)
model_key, loader_key, loss_key, sample_key = jr.split(key, 4)
dataset, weights, mean, std, img, width, height = get_data(in_path)
dataset_size, data_size = dataset.shape
dataloader = DataLoader((dataset, weights), batch_size, key=loader_key)
model = CNF(
data_size=data_size,
exact_logp=exact_logp,
num_blocks=num_blocks,
width_size=width_size,
depth=depth,
key=model_key,
)
optim = optax.adamw(lr, weight_decay=weight_decay)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
@eqx.filter_value_and_grad
def loss(model, data, weight, loss_key):
batch_size, _ = data.shape
noise_key, train_key = jr.split(loss_key, 2)
train_key = jr.split(train_key, batch_size)
data = data + jr.normal(noise_key, data.shape) * 0.5 / std
log_likelihood = jax.vmap(model.train)(data, key=train_key)
return -jnp.mean(weight * log_likelihood) # minimise negative log-likelihood
@eqx.filter_jit
def make_step(model, opt_state, step, loss_key):
# We only need gradients with respect to floating point JAX arrays, not any
# other part of our model. (e.g. the `exact_logp` flag. What would it even mean
# to differentiate that? Note that `eqx.filter_value_and_grad` does the same
# filtering by `eqx.is_inexact_array` by default.)
value = 0
grads = jax.tree_util.tree_map(
lambda leaf: jnp.zeros_like(leaf) if eqx.is_inexact_array(leaf) else None,
model,
)
# Get more accurate gradients by accumulating gradients over multiple batches.
# (Or equivalently, get lower memory requirements by splitting up a batch over
# multiple steps.)
def make_virtual_step(_, state):
value, grads, step, loss_key = state
data, weight = dataloader(step)
value_, grads_ = loss(model, data, weight, loss_key)
value = value + value_
grads = jax.tree_util.tree_map(lambda a, b: a + b, grads, grads_)
step = step + 1
loss_key = jr.split(loss_key, 1)[0]
return value, grads, step, loss_key
value, grads, step, loss_key = lax.fori_loop(
0, virtual_batches, make_virtual_step, (value, grads, step, loss_key)
)
value = value / virtual_batches
grads = jax.tree_util.tree_map(lambda a: a / virtual_batches, grads)
updates, opt_state = optim.update(
grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
)
model = eqx.apply_updates(model, updates)
return value, model, opt_state, step, loss_key
step = 0
while step < steps:
start = time.time()
value, model, opt_state, step, loss_key = make_step(
model, opt_state, step, loss_key
)
end = time.time()
if (step % print_every) == 0 or step == steps - 1:
print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")
num_samples = 5000
sample_key = jr.split(sample_key, num_samples)
samples = jax.vmap(model.sample)(key=sample_key)
sample_flows = jax.vmap(model.sample_flow, out_axes=-1)(key=sample_key)
fig, (*axs, ax, axtrue) = plt.subplots(
1,
2 + len(sample_flows),
figsize=((2 + len(sample_flows)) * 10 * height / width, 10),
)
samples = samples * std + mean
x = samples[:, 0]
y = samples[:, 1]
ax.scatter(x, y, c="black", s=2)
ax.set_xlim(-0.5, width - 0.5)
ax.set_ylim(-0.5, height - 0.5)
ax.set_aspect(height / width)
ax.set_xticks([])
ax.set_yticks([])
axtrue.imshow(img.T, origin="lower", cmap="gray")
axtrue.set_aspect(height / width)
axtrue.set_xticks([])
axtrue.set_yticks([])
x_resolution = 100
y_resolution = int(x_resolution * (height / width))
sample_flows = sample_flows * std[:, None] + mean[:, None]
x_pos, y_pos = jnp.broadcast_arrays(
jnp.linspace(-1, width + 1, x_resolution)[:, None],
jnp.linspace(-1, height + 1, y_resolution)[None, :],
)
positions = jnp.stack([jnp.ravel(x_pos), jnp.ravel(y_pos)])
densities = [stats.gaussian_kde(samples)(positions) for samples in sample_flows]
for ax, density in zip(axs, densities):
density = jnp.reshape(density, (x_resolution, y_resolution))
ax.imshow(density.T, origin="lower", cmap="plasma")
ax.set_aspect(height / width)
ax.set_xticks([])
ax.set_yticks([])
plt.savefig(out_path)
plt.show()
And now the following commands will reproduce the images displayed at the start.
main(in_path="../imgs/cat.png")
main(in_path="../imgs/butterfly.png", num_blocks=3)
main(in_path="../imgs/target.png", width_size=128)
Let's run the first one of those as a demonstration.
main(in_path="../imgs/cat.png")
Step: 100, Loss: 1.3118890523910522, Computation time: 1.0344874858856201
Step: 200, Loss: 1.251081943511963, Computation time: 1.0480339527130127
Step: 300, Loss: 1.2637590169906616, Computation time: 1.053438425064087
Step: 400, Loss: 1.238480567932129, Computation time: 1.1258141994476318
Step: 500, Loss: 1.2385149002075195, Computation time: 1.1150507926940918
Step: 600, Loss: 1.1769757270812988, Computation time: 1.0526695251464844
Step: 700, Loss: 1.2026259899139404, Computation time: 1.0425992012023926
Step: 800, Loss: 1.217708706855774, Computation time: 1.0489792823791504
Step: 900, Loss: 1.227858543395996, Computation time: 1.0484890937805176
Step: 1000, Loss: 1.189058780670166, Computation time: 1.0530979633331299
Step: 1100, Loss: 1.1783877611160278, Computation time: 1.0550987720489502
Step: 1200, Loss: 1.150791049003601, Computation time: 1.0446579456329346
Step: 1300, Loss: 1.221590518951416, Computation time: 1.0396397113800049
Step: 1400, Loss: 1.1535735130310059, Computation time: 1.0444588661193848
Step: 1500, Loss: 1.2072296142578125, Computation time: 1.0372049808502197
Step: 1600, Loss: 1.1677805185317993, Computation time: 1.0316574573516846
Step: 1700, Loss: 1.1480119228363037, Computation time: 1.037048101425171
Step: 1800, Loss: 1.1469142436981201, Computation time: 1.0355620384216309
Step: 1900, Loss: 1.1540181636810303, Computation time: 1.0425419807434082
Step: 2000, Loss: 1.170578122138977, Computation time: 1.034224271774292
Step: 2100, Loss: 1.1531469821929932, Computation time: 1.110149621963501
Step: 2200, Loss: 1.1332671642303467, Computation time: 1.1052398681640625
Step: 2300, Loss: 1.1843323707580566, Computation time: 1.0217394828796387
Step: 2400, Loss: 1.1606663465499878, Computation time: 1.0373609066009521
Step: 2500, Loss: 1.132057785987854, Computation time: 1.029083013534546
Step: 2600, Loss: 1.1429660320281982, Computation time: 1.0214695930480957
Step: 2700, Loss: 1.152261734008789, Computation time: 1.041821002960205
Step: 2800, Loss: 1.1637940406799316, Computation time: 1.023808240890503
Step: 2900, Loss: 1.1682878732681274, Computation time: 1.0261313915252686
Step: 3000, Loss: 1.1528184413909912, Computation time: 1.0209197998046875
Step: 3100, Loss: 1.1718814373016357, Computation time: 1.0247409343719482
Step: 3200, Loss: 1.1433460712432861, Computation time: 1.0423102378845215
Step: 3300, Loss: 1.166672706604004, Computation time: 1.025174856185913
Step: 3400, Loss: 1.1842900514602661, Computation time: 1.0576817989349365
Step: 3500, Loss: 1.1458779573440552, Computation time: 1.043668270111084
Step: 3600, Loss: 1.158961296081543, Computation time: 1.0318820476531982
Step: 3700, Loss: 1.157321810722351, Computation time: 1.042715072631836
Step: 3800, Loss: 1.1473262310028076, Computation time: 1.0452227592468262
Step: 3900, Loss: 1.1316838264465332, Computation time: 1.0491411685943604
Step: 4000, Loss: 1.149780035018921, Computation time: 1.0549867153167725
Step: 4100, Loss: 1.140995740890503, Computation time: 1.0537455081939697
Step: 4200, Loss: 1.1636855602264404, Computation time: 1.0538051128387451
Step: 4300, Loss: 1.1459300518035889, Computation time: 1.050825834274292
Step: 4400, Loss: 1.1132702827453613, Computation time: 1.0565142631530762
Step: 4500, Loss: 1.1445338726043701, Computation time: 1.0447840690612793
Step: 4600, Loss: 1.185341477394104, Computation time: 1.0458967685699463
Step: 4700, Loss: 1.1719266176223755, Computation time: 1.0501148700714111
Step: 4800, Loss: 1.1642041206359863, Computation time: 1.0497636795043945
Step: 4900, Loss: 1.1277761459350586, Computation time: 1.0412952899932861
Step: 5000, Loss: 1.1528496742248535, Computation time: 1.0490319728851318
Step: 5100, Loss: 1.1376690864562988, Computation time: 1.0476291179656982
Step: 5200, Loss: 1.1241021156311035, Computation time: 1.0485484600067139
Step: 5300, Loss: 1.186220645904541, Computation time: 1.054675817489624
Step: 5400, Loss: 1.1623786687850952, Computation time: 1.0538649559020996
Step: 5500, Loss: 1.156112551689148, Computation time: 1.0428664684295654
Step: 5600, Loss: 1.1495476961135864, Computation time: 1.0700688362121582
Step: 5700, Loss: 1.1711654663085938, Computation time: 1.056363821029663
Step: 5800, Loss: 1.15413236618042, Computation time: 1.0482234954833984
Step: 5900, Loss: 1.13923180103302, Computation time: 1.03656005859375
Step: 6000, Loss: 1.1366792917251587, Computation time: 1.0534214973449707
Step: 6100, Loss: 1.1102700233459473, Computation time: 1.0474035739898682
Step: 6200, Loss: 1.1211810111999512, Computation time: 1.053070068359375
Step: 6300, Loss: 1.1677824258804321, Computation time: 1.0389189720153809
Step: 6400, Loss: 1.152880072593689, Computation time: 1.0483548641204834
Step: 6500, Loss: 1.1546416282653809, Computation time: 1.073331594467163
Step: 6600, Loss: 1.1045620441436768, Computation time: 1.0458285808563232
Step: 6700, Loss: 1.1473908424377441, Computation time: 1.0490500926971436
Step: 6800, Loss: 1.1420358419418335, Computation time: 1.0393445491790771
Step: 6900, Loss: 1.1263903379440308, Computation time: 1.0471012592315674
Step: 7000, Loss: 1.1546120643615723, Computation time: 1.120964765548706
Step: 7100, Loss: 1.1234941482543945, Computation time: 1.1098055839538574
Step: 7200, Loss: 1.1735379695892334, Computation time: 1.112914800643921
Step: 7300, Loss: 1.1549310684204102, Computation time: 1.1177945137023926
Step: 7400, Loss: 1.1674282550811768, Computation time: 1.0430011749267578
Step: 7500, Loss: 1.1249209642410278, Computation time: 1.0475146770477295
Step: 7600, Loss: 1.149993896484375, Computation time: 1.0517895221710205
Step: 7700, Loss: 1.126546859741211, Computation time: 1.0522327423095703
Step: 7800, Loss: 1.0991778373718262, Computation time: 1.034656286239624
Step: 7900, Loss: 1.1408506631851196, Computation time: 1.0303537845611572
Step: 8000, Loss: 1.1567769050598145, Computation time: 1.0465099811553955
Step: 8100, Loss: 1.1207897663116455, Computation time: 1.0353443622589111
Step: 8200, Loss: 1.1423345804214478, Computation time: 1.0374336242675781
Step: 8300, Loss: 1.1438696384429932, Computation time: 1.0480520725250244
Step: 8400, Loss: 1.184295654296875, Computation time: 1.046989917755127
Step: 8500, Loss: 1.1350500583648682, Computation time: 1.0438358783721924
Step: 8600, Loss: 1.1440303325653076, Computation time: 1.0357015132904053
Step: 8700, Loss: 1.1465599536895752, Computation time: 1.0545375347137451
Step: 8800, Loss: 1.116767406463623, Computation time: 1.0341439247131348
Step: 8900, Loss: 1.1343653202056885, Computation time: 1.0588762760162354
Step: 9000, Loss: 1.1371710300445557, Computation time: 1.0479660034179688
Step: 9100, Loss: 1.125238299369812, Computation time: 1.0390236377716064
Step: 9200, Loss: 1.1323764324188232, Computation time: 1.0565376281738281
Step: 9300, Loss: 1.1228868961334229, Computation time: 1.0413126945495605
Step: 9400, Loss: 1.1537754535675049, Computation time: 1.0376060009002686
Step: 9500, Loss: 1.157130241394043, Computation time: 1.0577881336212158
Step: 9600, Loss: 1.1420390605926514, Computation time: 1.058633804321289
Step: 9700, Loss: 1.14859139919281, Computation time: 1.0608363151550293
Step: 9800, Loss: 1.0945243835449219, Computation time: 1.0353436470031738
Step: 9900, Loss: 1.1538352966308594, Computation time: 1.0435452461242676
Step: 10000, Loss: 1.1333094835281372, Computation time: 1.0515074729919434