Continuous Normalising Flow
This example is a bit of fun! It constructs a continuous normalising flow (CNF) to learn a distribution specified by a (greyscale) image. That is, the target distribution is over \(\mathbb{R}^2\), and the image specifies the (unnormalised) density at each point.
You can specify your own images, and learn your own flows.
Some example outputs from this script:
 
 

Reference:
@article{grathwohl2019ffjord,
    title={{FFJORD}: {F}ree-form {C}ontinuous {D}ynamics for {S}calable {R}eversible
           {G}enerative {M}odels},
    author={Grathwohl, Will and Chen, Ricky T. Q. and Bettencourt, Jesse and
            Sutskever, Ilya and Duvenaud, David},
    journal={International Conference on Learning Representations},
    year={2019},
}
This example is available as a Jupyter notebook here.
Warning
This example will need a GPU to run efficiently.
Advanced example
This is a pretty advanced example.
import math
import os
import pathlib
import time
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import imageio
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
import scipy.stats as stats
# Directory the script is run from; default output figures are written here.
here = pathlib.Path.cwd()
First let's define some vector fields. This is basically just an MLP, using tanh as the activation function and "ConcatSquash" instead of linear layers.
This is the vector field on the right hand side of the ODE.
class Func(eqx.Module):
    """Time-dependent vector field f(t, y): the right hand side of the CNF's ODE.

    An MLP made of ``depth + 1`` ConcatSquash layers with tanh activations between
    them (none after the final layer).
    """

    layers: list[eqx.nn.Linear]  # elements are ConcatSquash layers (see __init__)

    def __init__(self, *, data_size, width_size, depth, key, **kwargs):
        super().__init__(**kwargs)
        keys = jr.split(key, depth + 1)
        # Layer j maps sizes[j] -> sizes[j + 1] and is initialised from keys[j]; with
        # depth == 0 this is a single data_size -> data_size layer.
        sizes = [data_size] + [width_size] * depth + [data_size]
        self.layers = [
            ConcatSquash(in_size=n_in, out_size=n_out, key=layer_key)
            for n_in, n_out, layer_key in zip(sizes[:-1], sizes[1:], keys)
        ]

    def __call__(self, t, y, args):
        # ConcatSquash consumes time as a length-1 vector.
        t = jnp.asarray(t)[None]
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            y = jnn.tanh(layer(t, y))
        return output_layer(t, y)
# Credit: this layer, and some of the default hyperparameters below, are taken from the
# FFJORD repo.
class ConcatSquash(eqx.Module):
    """Linear layer gated and shifted by functions of time.

    Computes ``lin1(y) * sigmoid(lin2(t)) + lin3(t)``, where ``t`` is a length-1 array.
    """

    lin1: eqx.nn.Linear  # in_size -> out_size, applied to the state y
    lin2: eqx.nn.Linear  # 1 -> out_size, time-dependent gate (through a sigmoid)
    lin3: eqx.nn.Linear  # 1 -> out_size, time-dependent shift (no bias)

    def __init__(self, *, in_size, out_size, key, **kwargs):
        super().__init__(**kwargs)
        y_key, gate_key, shift_key = jr.split(key, 3)
        self.lin1 = eqx.nn.Linear(in_size, out_size, key=y_key)
        self.lin2 = eqx.nn.Linear(1, out_size, key=gate_key)
        self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=shift_key)

    def __call__(self, t, y):
        gate = jnn.sigmoid(self.lin2(t))
        shift = self.lin3(t)
        return self.lin1(y) * gate + shift
When training, we need to wrap our vector fields in something that also computes the change in log-density.
This can be done either approximately (using Hutchinson's trace estimator) or exactly (the divergence of the vector field; relatively computationally expensive).
def approx_logp_wrapper(t, y, args):
    """Augmented vector field using Hutchinson's estimator for the divergence.

    ``y`` is a ``(state, logp)`` pair; only ``state`` is read. ``args`` is
    ``(*extra, eps, func)``: the vector field is ``func(t, state, extra)``, with
    ``extra`` passed as a list, and ``eps`` is the probe vector.

    Returns ``(func(t, state, extra), sum(eps * (eps^T J)))``, where ``J`` is the
    Jacobian of the vector field with respect to ``state``.
    """
    state, _ = y
    *extra, eps, func = args
    f, pullback = jax.vjp(lambda s: func(t, s, extra), state)
    # A single vector-Jacobian product gives eps^T J; dotting with eps estimates tr(J).
    (eps_jac,) = pullback(eps)
    return f, jnp.sum(eps_jac * eps)
def exact_logp_wrapper(t, y, args):
    """Augmented vector field using the exact divergence (trace of the Jacobian).

    ``y`` is a ``(state, logp)`` pair; only ``state`` is read, and it must be 1D.
    ``args`` is ``(*extra, eps, func)``: ``eps`` is ignored here (it exists so that
    both wrappers share one signature), and the vector field is
    ``func(t, state, extra)`` with ``extra`` passed as a list.

    Returns ``(func(t, state, extra), trace(J))``. This costs one VJP per state
    dimension.
    """
    state, _ = y
    *extra, _, func = args
    f, pullback = jax.vjp(lambda s: func(t, s, extra), state)
    (size,) = state.shape  # this implementation only works for 1D input
    # Pulling back each standard basis vector e_i yields row i of the Jacobian.
    (jac,) = jax.vmap(pullback)(jnp.eye(size))
    return f, jnp.trace(jac)
Wrap up the differential equation solve into a model.
def normal_log_likelihood(y):
    """Log-density of ``y`` under a standard normal, summed over all its elements."""
    squared_norm = jnp.sum(y**2)
    normaliser = y.size * math.log(2 * math.pi)
    return -0.5 * (normaliser + squared_norm)
class CNF(eqx.Module):
    """Continuous normalising flow made of ``num_blocks`` ODE blocks.

    Each block integrates its own ``Func`` vector field from ``t0`` to ``t1`` using
    Tsit5 with fixed step ``dt0``. Sampling pushes standard-normal noise forward
    through the blocks in list order. ``train`` runs the blocks in reverse, backward
    in time, to evaluate the log-likelihood of a data point.
    """

    funcs: list[Func]  # one vector field per block, applied in list order when sampling
    data_size: int
    exact_logp: bool  # True: exact Jacobian trace; False: Hutchinson's estimator
    t0: float
    t1: float
    dt0: float

    def __init__(
        self,
        *,
        data_size,
        exact_logp,
        num_blocks,
        width_size,
        depth,
        key,
        **kwargs,
    ):
        super().__init__(**kwargs)
        keys = jr.split(key, num_blocks)
        self.funcs = [
            Func(
                data_size=data_size,
                width_size=width_size,
                depth=depth,
                key=k,
            )
            for k in keys
        ]
        self.data_size = data_size
        self.exact_logp = exact_logp
        self.t0 = 0.0
        self.t1 = 0.5
        self.dt0 = 0.05

    # Runs backward-in-time to train the CNF.
    def train(self, y, *, key):
        """Return the log-likelihood of the single data point ``y``.

        ``key`` samples the probe vector for Hutchinson's estimator. It is drawn even
        when ``exact_logp`` is True, in which case the wrapper ignores it.
        """
        if self.exact_logp:
            term = diffrax.ODETerm(exact_logp_wrapper)
        else:
            term = diffrax.ODETerm(approx_logp_wrapper)
        solver = diffrax.Tsit5()
        eps = jr.normal(key, y.shape)
        delta_log_likelihood = 0.0
        for func in reversed(self.funcs):
            # Augment the state with the accumulated change in log-density.
            y = (y, delta_log_likelihood)
            sol = diffrax.diffeqsolve(
                term, solver, self.t1, self.t0, -self.dt0, y, (eps, func)
            )
            # Only the final time is saved, so each output has a length-1 leading axis.
            (y,), (delta_log_likelihood,) = sol.ys
        return delta_log_likelihood + normal_log_likelihood(y)

    # Runs forward-in-time to draw samples from the CNF.
    def sample(self, *, key):
        """Draw one sample by flowing standard-normal noise through every block."""
        solver = diffrax.Tsit5()
        y = jr.normal(key, (self.data_size,))
        for func in self.funcs:
            term = diffrax.ODETerm(func)
            sol = diffrax.diffeqsolve(term, solver, self.t0, self.t1, self.dt0, y)
            (y,) = sol.ys
        return y

    # To make illustrations, we have a variant sample method we can query to see the
    # evolution of the samples during the forward solve.
    def sample_flow(self, *, key):
        """Return one sample's state at 6 evenly spaced times along the whole flow.

        The blocks are laid end to end on a global timeline of length
        ``(t1 - t0) * num_blocks``. The result has shape ``(6, data_size)``, and its
        final row is the same point that ``sample`` returns for the same key.
        """
        block_length = self.t1 - self.t0
        t_so_far = self.t0  # global time at which the current block starts
        t_end = self.t0 + block_length * len(self.funcs)
        save_times = jnp.linspace(self.t0, t_end, 6)
        y = jr.normal(key, (self.data_size,))
        solver = diffrax.Tsit5()
        out = []
        for i, func in enumerate(self.funcs):
            in_block = t_so_far <= save_times
            if i != len(self.funcs) - 1:
                # Half-open window: a time on a block boundary is saved by the next block.
                in_block = in_block & (save_times < t_so_far + block_length)
            # Map global times into this block's local [t0, t1] window; the clip only
            # guards against floating-point rounding pushing a time out of range.
            save_ts = jnp.clip(
                save_times[in_block] - t_so_far + self.t0, self.t0, self.t1
            )
            t_so_far = t_so_far + block_length
            term = diffrax.ODETerm(func)
            # Also save at t1 (stored last in sol.ys). The next block must start from
            # this block's *final* state, not from the last requested save time, which
            # generally lies before t1. This also works when save_ts is empty.
            saveat = diffrax.SaveAt(ts=save_ts, t1=True)
            sol = diffrax.diffeqsolve(
                term, solver, self.t0, self.t1, self.dt0, y, saveat=saveat
            )
            out.append(sol.ys[:-1])
            y = sol.ys[-1]
        out = jnp.concatenate(out)
        assert len(out) == 6  # number of points we saved at
        return out
Alright, that's the models done. Now let's get some data.
First we have a function for taking the specified input image, and turning it into data.
def get_data(path):
    """Load an image and convert it into a weighted, standardised 2D point cloud.

    Each pixel becomes a point at its (x, y) pixel coordinate, with y increasing
    upwards. Its weight is ``1 - intensity / max_intensity``, so darker pixels carry
    more mass. Near-white pixels (intensity >= 254) are dropped for efficiency, since
    their weight is (close to) zero.

    Greyscale, greyscale+alpha, RGB and RGBA images are accepted; alpha is ignored.

    Returns:
        dataset: (num_points, 2) point coordinates, standardised per axis.
        weights: (num_points,) per-point weights.
        mean, std: (2,) statistics used to standardise ``dataset``.
        img: (width, height) greyscale image, flipped so that index [x, y] has y
            increasing upwards.
        width, height: image dimensions in pixels.

    Raises:
        ValueError: if the image layout is unsupported, or if no pixel is darker than
            the near-white threshold (the point cloud would be empty).
    """
    # Assumed 8-bit, values in {0, ..., 255}; the 254 threshold below relies on this.
    img = jnp.asarray(imageio.imread(path))
    original_shape = img.shape
    if img.ndim == 3 and img.shape[-1] in (2, 4):
        img = img[..., :-1]  # ignore alpha channel
    if img.ndim == 3 and img.shape[-1] == 3:
        # Convert to greyscale for simplicity.
        img = img @ jnp.array([0.2989, 0.5870, 0.1140])
    elif img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0].astype(jnp.float32)
    elif img.ndim == 2:
        img = img.astype(jnp.float32)  # already greyscale
    else:
        raise ValueError(
            f"Unsupported image shape {original_shape}; expected a greyscale, "
            "greyscale+alpha, RGB or RGBA image."
        )
    height, width = img.shape
    img = jnp.transpose(img)[:, ::-1]  # (width, height)
    x = jnp.arange(width, dtype=jnp.float32)
    y = jnp.arange(height, dtype=jnp.float32)
    x, y = jnp.broadcast_arrays(x[:, None], y[None, :])
    max_intensity = jnp.max(img)
    # An all-black image has max 0; every pixel should then get full weight, not NaN.
    max_intensity = jnp.where(max_intensity > 0, max_intensity, 1.0)
    weights = 1 - img.reshape(-1).astype(jnp.float32) / max_intensity
    dataset = jnp.stack(
        [x.reshape(-1), y.reshape(-1)], axis=-1
    )  # shape (dataset_size, 2)
    # For efficiency we don't bother with the particles that will have weight zero.
    cond = img.reshape(-1) < 254
    if not bool(jnp.any(cond)):
        raise ValueError(
            f"Image {path!s} has no pixels darker than the near-white threshold; "
            "there is no density to learn."
        )
    dataset = dataset[cond]
    weights = weights[cond]
    mean = jnp.mean(dataset, axis=0)
    std = jnp.std(dataset, axis=0) + 1e-6
    dataset = (dataset - mean) / std
    return dataset, weights, mean, std, img, width, height
Now to load the data during training, we need a dataloader. In this case our dataset is small enough to fit in-memory, so we use a dataloader implementation that we can include within our overall JIT wrapper, for speed.
class DataLoader(eqx.Module):
    """In-memory dataloader that can be called inside a JIT-compiled step.

    All ``arrays`` are indexed together along their leading axis. Calling the loader
    with a global step number returns the corresponding minibatch. Each epoch uses a
    fresh permutation derived from ``key`` and the epoch number. A trailing partial
    batch (``dataset_size % batch_size`` items) is skipped in every epoch.
    """

    arrays: tuple[jnp.ndarray, ...]
    batch_size: int
    key: jax.Array  # PRNG key used to derive per-epoch permutations

    def __check_init__(self):
        dataset_size = self.arrays[0].shape[0]
        if not all(array.shape[0] == dataset_size for array in self.arrays):
            raise ValueError("All arrays must have the same leading dimension.")
        # Zero batches per epoch would make `step // num_batches` divide by zero and
        # the slice in __call__ overrun the permutation; reject it up front.
        if not 0 < self.batch_size <= dataset_size:
            raise ValueError(
                f"batch_size must be between 1 and the dataset size ({dataset_size}); "
                f"got {self.batch_size}."
            )

    def __call__(self, step):
        """Return the minibatch for global step ``step``, as a tuple like ``arrays``."""
        dataset_size = self.arrays[0].shape[0]
        num_batches = dataset_size // self.batch_size
        epoch = step // num_batches
        # Every step within one epoch sees the same permutation, so the batches of an
        # epoch are disjoint.
        key = jr.fold_in(self.key, epoch)
        perm = jr.permutation(key, jnp.arange(dataset_size))
        start = (step % num_batches) * self.batch_size
        batch_indices = lax.dynamic_slice_in_dim(perm, start, self.batch_size)
        return tuple(array[batch_indices] for array in self.arrays)
Bring everything together. This function is our entry point.
def main(
    in_path,
    out_path=None,
    batch_size=500,
    virtual_batches=2,
    lr=1e-3,
    weight_decay=1e-5,
    steps=10000,
    exact_logp=True,
    num_blocks=2,
    width_size=64,
    depth=3,
    print_every=100,
    seed=5678,
):
    """Train a CNF on the image at ``in_path``, then plot and save the results.

    The figure shows, left to right: kernel density estimates of the samples at the
    times returned by ``CNF.sample_flow``, a scatter plot of final samples, and the
    greyscale target image.

    Args:
        in_path: path to the target image.
        out_path: where to save the figure. Defaults to the current working directory
            with the input image's file name. A ``_cnf`` suffix is added if that would
            overwrite the input image itself.
        batch_size: data points per (virtual) batch.
        virtual_batches: number of batches whose gradients are averaged per update.
        lr, weight_decay: AdamW hyperparameters.
        steps: total number of batches to consume. Each update consumes
            ``virtual_batches`` of them.
        exact_logp: use the exact divergence if True, otherwise Hutchinson's estimator.
        num_blocks, width_size, depth: CNF architecture.
        print_every: print the loss whenever the batch counter is a multiple of this
            (the final update is always reported).
        seed: PRNG seed.
    """
    in_file = pathlib.Path(in_path)
    if out_path is None:
        out_path = here / in_file.name
        # Never silently overwrite the input image (e.g. when run from its directory).
        if out_path.resolve() == in_file.resolve():
            out_path = out_path.with_name(f"{out_path.stem}_cnf{out_path.suffix}")
    else:
        out_path = pathlib.Path(out_path)
    key = jr.PRNGKey(seed)
    model_key, loader_key, loss_key, sample_key = jr.split(key, 4)

    dataset, weights, mean, std, img, width, height = get_data(in_path)
    _, data_size = dataset.shape
    dataloader = DataLoader((dataset, weights), batch_size, key=loader_key)

    model = CNF(
        data_size=data_size,
        exact_logp=exact_logp,
        num_blocks=num_blocks,
        width_size=width_size,
        depth=depth,
        key=model_key,
    )

    optim = optax.adamw(lr, weight_decay=weight_decay)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_value_and_grad
    def loss(model, data, weight, loss_key):
        # Weighted negative log-likelihood of one batch. Gaussian jitter with std 0.5
        # pixels (in standardised coordinates) is added to the points; presumably this
        # dequantises the integer pixel grid.
        batch_size, _ = data.shape
        noise_key, train_key = jr.split(loss_key, 2)
        train_key = jr.split(train_key, batch_size)
        data = data + jr.normal(noise_key, data.shape) * 0.5 / std
        log_likelihood = jax.vmap(model.train)(data, key=train_key)
        return -jnp.mean(weight * log_likelihood)  # minimise negative log-likelihood

    @eqx.filter_jit
    def make_step(model, opt_state, step, loss_key):
        # We only need gradients with respect to floating point JAX arrays, not any
        # other part of our model. (e.g. the `exact_logp` flag. What would it even mean
        # to differentiate that? Note that `eqx.filter_value_and_grad` does the same
        # filtering by `eqx.is_inexact_array` by default.)
        value = 0.0
        grads = jax.tree_util.tree_map(
            lambda leaf: jnp.zeros_like(leaf) if eqx.is_inexact_array(leaf) else None,
            model,
        )

        # Get more accurate gradients by accumulating gradients over multiple batches.
        # (Or equivalently, get lower memory requirements by splitting up a batch over
        # multiple steps.)
        def make_virtual_step(_, state):
            value, grads, step, loss_key = state
            data, weight = dataloader(step)
            value_, grads_ = loss(model, data, weight, loss_key)
            value = value + value_
            grads = jax.tree_util.tree_map(lambda a, b: a + b, grads, grads_)
            step = step + 1
            loss_key = jr.split(loss_key, 1)[0]
            return value, grads, step, loss_key

        value, grads, step, loss_key = lax.fori_loop(
            0, virtual_batches, make_virtual_step, (value, grads, step, loss_key)
        )
        value = value / virtual_batches
        grads = jax.tree_util.tree_map(lambda a: a / virtual_batches, grads)
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
        )
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state, step, loss_key

    step = 0
    while step < steps:
        start = time.time()
        value, model, opt_state, step, loss_key = make_step(
            model, opt_state, step, loss_key
        )
        # JAX dispatches work asynchronously. Wait for this update to finish so the
        # reported time measures computation rather than just dispatch.
        jax.block_until_ready((value, model, opt_state))
        end = time.time()
        # `step` advances by `virtual_batches` per update, so it may never equal
        # `steps - 1`; test `>= steps` to always report the final update.
        if (step % print_every) == 0 or step >= steps:
            print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")

    num_samples = 5000
    sample_key = jr.split(sample_key, num_samples)
    samples = jax.vmap(model.sample)(key=sample_key)
    # Shape (num_save_times, data_size, num_samples).
    sample_flows = jax.vmap(model.sample_flow, out_axes=-1)(key=sample_key)
    fig, (*axs, ax, axtrue) = plt.subplots(
        1,
        2 + len(sample_flows),
        figsize=((2 + len(sample_flows)) * 10 * height / width, 10),
    )

    # Scatter plot of final samples, mapped back to pixel coordinates.
    samples = samples * std + mean
    x = samples[:, 0]
    y = samples[:, 1]
    ax.scatter(x, y, c="black", s=2)
    ax.set_xlim(-0.5, width - 0.5)
    ax.set_ylim(-0.5, height - 0.5)
    ax.set_aspect(height / width)
    ax.set_xticks([])
    ax.set_yticks([])

    # Target image.
    axtrue.imshow(img.T, origin="lower", cmap="gray")
    axtrue.set_aspect(height / width)
    axtrue.set_xticks([])
    axtrue.set_yticks([])

    # Kernel density estimates of the samples at each saved time along the flow.
    x_resolution = 100
    y_resolution = int(x_resolution * (height / width))
    sample_flows = sample_flows * std[:, None] + mean[:, None]
    x_pos, y_pos = jnp.broadcast_arrays(
        jnp.linspace(-1, width + 1, x_resolution)[:, None],
        jnp.linspace(-1, height + 1, y_resolution)[None, :],
    )
    positions = jnp.stack([jnp.ravel(x_pos), jnp.ravel(y_pos)])
    densities = [stats.gaussian_kde(flow)(positions) for flow in sample_flows]
    for flow_ax, density in zip(axs, densities):
        density = jnp.reshape(density, (x_resolution, y_resolution))
        flow_ax.imshow(density.T, origin="lower", cmap="plasma")
        flow_ax.set_aspect(height / width)
        flow_ax.set_xticks([])
        flow_ax.set_yticks([])

    fig.savefig(out_path)
    plt.show()
And now the following commands will reproduce the images displayed at the start.
# Reproduce the three example figures; each call trains a fresh CNF from scratch.
for example_kwargs in (
    dict(in_path="../imgs/cat.png"),
    dict(in_path="../imgs/butterfly.png", num_blocks=3),
    dict(in_path="../imgs/target.png", width_size=128),
):
    main(**example_kwargs)
Let's run the first one of those as a demonstration.
# Train on the cat image with the default hyperparameters.
main("../imgs/cat.png")
Step: 100, Loss: 1.326553225517273, Computation time: 0.005589962005615234
Step: 200, Loss: 1.2921397686004639, Computation time: 0.004528999328613281
Step: 300, Loss: 1.2790848016738892, Computation time: 0.004530906677246094
Step: 400, Loss: 1.244489073753357, Computation time: 0.004511833190917969
Step: 500, Loss: 1.235985517501831, Computation time: 0.004565000534057617
Step: 600, Loss: 1.202713966369629, Computation time: 0.004781961441040039
Step: 700, Loss: 1.2291572093963623, Computation time: 0.00454401969909668
Step: 800, Loss: 1.1679149866104126, Computation time: 0.0047419071197509766
Step: 900, Loss: 1.2050381898880005, Computation time: 0.004331111907958984
Step: 1000, Loss: 1.1861329078674316, Computation time: 0.0062639713287353516
Step: 1100, Loss: 1.1747440099716187, Computation time: 0.004642963409423828
Step: 1200, Loss: 1.2095255851745605, Computation time: 0.0046939849853515625
Step: 1300, Loss: 1.1605581045150757, Computation time: 0.004503011703491211
Step: 1400, Loss: 1.1921112537384033, Computation time: 0.004410982131958008
Step: 1500, Loss: 1.1909801959991455, Computation time: 0.004683017730712891
Step: 1600, Loss: 1.1826883554458618, Computation time: 0.00436091423034668
Step: 1700, Loss: 1.1849846839904785, Computation time: 0.00456690788269043
Step: 1800, Loss: 1.1776082515716553, Computation time: 0.004513740539550781
Step: 1900, Loss: 1.1725499629974365, Computation time: 0.004731178283691406
Step: 2000, Loss: 1.1645245552062988, Computation time: 0.0044748783111572266
Step: 2100, Loss: 1.1961400508880615, Computation time: 0.004731178283691406
Step: 2200, Loss: 1.2048395872116089, Computation time: 0.004636049270629883
Step: 2300, Loss: 1.1582252979278564, Computation time: 0.004842042922973633
Step: 2400, Loss: 1.1884621381759644, Computation time: 0.0047490596771240234
Step: 2500, Loss: 1.167927622795105, Computation time: 0.004518747329711914
Step: 2600, Loss: 1.1785601377487183, Computation time: 0.0047168731689453125
Step: 2700, Loss: 1.1697403192520142, Computation time: 0.0045888423919677734
Step: 2800, Loss: 1.1548969745635986, Computation time: 0.0052530765533447266
Step: 2900, Loss: 1.1366448402404785, Computation time: 0.004511833190917969
Step: 3000, Loss: 1.1529115438461304, Computation time: 0.007226228713989258
Step: 3100, Loss: 1.1682521104812622, Computation time: 0.007007122039794922
Step: 3200, Loss: 1.1438877582550049, Computation time: 0.004540205001831055
Step: 3300, Loss: 1.167274832725525, Computation time: 0.004748106002807617
Step: 3400, Loss: 1.122734785079956, Computation time: 0.004651069641113281
Step: 3500, Loss: 1.1813410520553589, Computation time: 0.004653215408325195
Step: 3600, Loss: 1.1492958068847656, Computation time: 0.0044820308685302734
Step: 3700, Loss: 1.155427098274231, Computation time: 0.004559040069580078
Step: 3800, Loss: 1.1665260791778564, Computation time: 0.004667043685913086
Step: 3900, Loss: 1.1424567699432373, Computation time: 0.0060329437255859375
Step: 4000, Loss: 1.12287437915802, Computation time: 0.0052568912506103516
Step: 4100, Loss: 1.1657789945602417, Computation time: 0.004556179046630859
Step: 4200, Loss: 1.1385401487350464, Computation time: 0.0044519901275634766
Step: 4300, Loss: 1.1647453308105469, Computation time: 0.004374027252197266
Step: 4400, Loss: 1.1640743017196655, Computation time: 0.0047757625579833984
Step: 4500, Loss: 1.1143943071365356, Computation time: 0.004446268081665039
Step: 4600, Loss: 1.1555577516555786, Computation time: 0.004429817199707031
Step: 4700, Loss: 1.1433415412902832, Computation time: 0.004705905914306641
Step: 4800, Loss: 1.1407968997955322, Computation time: 0.0045168399810791016
Step: 4900, Loss: 1.1176646947860718, Computation time: 0.004727840423583984
Step: 5000, Loss: 1.1334787607192993, Computation time: 0.004560947418212891
Step: 5100, Loss: 1.1539499759674072, Computation time: 0.004340171813964844
Step: 5200, Loss: 1.1314822435379028, Computation time: 0.004489898681640625
Step: 5300, Loss: 1.1245366334915161, Computation time: 0.004488945007324219
Step: 5400, Loss: 1.1406058073043823, Computation time: 0.0045261383056640625
Step: 5500, Loss: 1.1327133178710938, Computation time: 0.004634857177734375
Step: 5600, Loss: 1.1390703916549683, Computation time: 0.00452876091003418
Step: 5700, Loss: 1.1343345642089844, Computation time: 0.0046558380126953125
Step: 5800, Loss: 1.1199959516525269, Computation time: 0.004340171813964844
Step: 5900, Loss: 1.128940224647522, Computation time: 0.004805088043212891
Step: 6000, Loss: 1.1478036642074585, Computation time: 0.0045511722564697266
Step: 6100, Loss: 1.1377071142196655, Computation time: 0.004395961761474609
Step: 6200, Loss: 1.1427260637283325, Computation time: 0.004565715789794922
Step: 6300, Loss: 1.1402515172958374, Computation time: 0.004658937454223633
Step: 6400, Loss: 1.1456530094146729, Computation time: 0.004646778106689453
Step: 6500, Loss: 1.1391595602035522, Computation time: 0.004822969436645508
Step: 6600, Loss: 1.1434180736541748, Computation time: 0.0047607421875
Step: 6700, Loss: 1.1506056785583496, Computation time: 0.004356861114501953
Step: 6800, Loss: 1.110101342201233, Computation time: 0.004297971725463867
Step: 6900, Loss: 1.1358736753463745, Computation time: 0.004586935043334961
Step: 7000, Loss: 1.1711256504058838, Computation time: 0.0043790340423583984
Step: 7100, Loss: 1.1467245817184448, Computation time: 0.004685163497924805
Step: 7200, Loss: 1.1316251754760742, Computation time: 0.0045278072357177734
Step: 7300, Loss: 1.1354844570159912, Computation time: 0.004673004150390625
Step: 7400, Loss: 1.1378566026687622, Computation time: 0.004387855529785156
Step: 7500, Loss: 1.1347078084945679, Computation time: 0.0054781436920166016
Step: 7600, Loss: 1.1254860162734985, Computation time: 0.004527091979980469
Step: 7700, Loss: 1.1208924055099487, Computation time: 0.00446009635925293
Step: 7800, Loss: 1.1302648782730103, Computation time: 0.004626274108886719
Step: 7900, Loss: 1.1369004249572754, Computation time: 0.004567861557006836
Step: 8000, Loss: 1.1482415199279785, Computation time: 0.006008625030517578
Step: 8100, Loss: 1.1277925968170166, Computation time: 0.005083799362182617
Step: 8200, Loss: 1.1214911937713623, Computation time: 0.0046117305755615234
Step: 8300, Loss: 1.1475985050201416, Computation time: 0.004723072052001953
Step: 8400, Loss: 1.1315407752990723, Computation time: 0.005590200424194336
Step: 8500, Loss: 1.137768268585205, Computation time: 0.004975795745849609
Step: 8600, Loss: 1.1586252450942993, Computation time: 0.004456758499145508
Step: 8700, Loss: 1.1627155542373657, Computation time: 0.004718780517578125
Step: 8800, Loss: 1.144771695137024, Computation time: 0.004578113555908203
Step: 8900, Loss: 1.1369574069976807, Computation time: 0.0045278072357177734
Step: 9000, Loss: 1.121504306793213, Computation time: 0.004799365997314453
Step: 9100, Loss: 1.1281462907791138, Computation time: 0.004580974578857422
Step: 9200, Loss: 1.1350762844085693, Computation time: 0.0045320987701416016
Step: 9300, Loss: 1.1381124258041382, Computation time: 0.004789829254150391
Step: 9400, Loss: 1.1473619937896729, Computation time: 0.004591941833496094
Step: 9500, Loss: 1.1407455205917358, Computation time: 0.005048036575317383
Step: 9600, Loss: 1.1148649454116821, Computation time: 0.004637718200683594
Step: 9700, Loss: 1.160093903541565, Computation time: 0.004717111587524414
Step: 9800, Loss: 1.138428807258606, Computation time: 0.0045239925384521484
Step: 9900, Loss: 1.1236419677734375, Computation time: 0.00479888916015625
Step: 10000, Loss: 1.1474863290786743, Computation time: 0.004647016525268555
