Skip to content

Continuous Normalising Flow

This example is a bit of fun! It constructs a continuous normalising flow (CNF) to learn a distribution specified by a (greyscale) image. That is, the target distribution is over \(\mathbb{R}^2\), and the image specifies the (unnormalised) density at each point: darker pixels correspond to higher density.

You can specify your own images, and learn your own flows.

Some example outputs from this script:

cat cat cat

Reference:

@article{grathwohl2019ffjord,
    title={{FFJORD}: {F}ree-form {C}ontinuous {D}ynamics for {S}calable {R}eversible
           {G}enerative {M}odels},
    author={Grathwohl, Will and Chen, Ricky T. Q. and Bettencourt, Jesse and
            Sutskever, Ilya and Duvenaud, David},
    journal={International Conference on Learning Representations},
    year={2019},
}

This example is available as a Jupyter notebook here.

Warning

This example will need a GPU to run efficiently.

Advanced example

This is a pretty advanced example.

import math
import os
import pathlib
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import imageio
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
import scipy.stats as stats


# Directory the script/notebook is run from; default location for output figures.
here = pathlib.Path.cwd()

First let's define some vector fields. This is basically just an MLP, using tanh as the activation function and "ConcatSquash" instead of linear layers.

This is the vector field on the right hand side of the ODE.

class Func(eqx.Module):
    """Vector field for one CNF block: an MLP of `ConcatSquash` layers with tanh.

    Called as `func(t, y, args)` (the signature diffrax passes to an `ODETerm`);
    `args` is ignored.
    """

    # `depth + 1` time-conditioned layers (not plain `eqx.nn.Linear`s).
    layers: list["ConcatSquash"]

    def __init__(self, *, data_size, width_size, depth, key, **kwargs):
        super().__init__(**kwargs)
        # Layer widths: data_size -> width_size (x depth) -> data_size. With
        # depth == 0 this is a single data_size -> data_size layer.
        sizes = [data_size] + [width_size] * depth + [data_size]
        keys = jr.split(key, depth + 1)  # one key per layer
        self.layers = [
            ConcatSquash(in_size=in_size, out_size=out_size, key=layer_key)
            for in_size, out_size, layer_key in zip(sizes[:-1], sizes[1:], keys)
        ]

    def __call__(self, t, y, args):
        # ConcatSquash's time layers take a length-1 input, so lift scalar t to (1,).
        t = jnp.asarray(t)[None]
        *hidden, output = self.layers
        for layer in hidden:
            y = jnn.tanh(layer(t, y))
        # No activation after the final layer.
        return output(t, y)


# Credit: this layer, and some of the default hyperparameters below, are taken from the
# FFJORD repo.
class ConcatSquash(eqx.Module):
    """Linear layer on `y` whose output is gated and shifted by functions of `t`.

    Computes `lin1(y) * sigmoid(lin2(t)) + lin3(t)`. Because `lin2` and `lin3` are
    built with input size 1, `t` must be an array of shape (1,).
    """

    lin1: eqx.nn.Linear  # in_size -> out_size map applied to y
    lin2: eqx.nn.Linear  # 1 -> out_size map of t, passed through a sigmoid gate
    lin3: eqx.nn.Linear  # 1 -> out_size map of t (no bias), added as a shift

    def __init__(self, *, in_size, out_size, key, **kwargs):
        super().__init__(**kwargs)
        main_key, gate_key, shift_key = jr.split(key, 3)
        self.lin1 = eqx.nn.Linear(in_size, out_size, key=main_key)
        self.lin2 = eqx.nn.Linear(1, out_size, key=gate_key)
        self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=shift_key)

    def __call__(self, t, y):
        gate = jnn.sigmoid(self.lin2(t))
        shift = self.lin3(t)
        return self.lin1(y) * gate + shift

When training, we need to wrap our vector fields in something that also computes the change in log-density.

This can be done either approximately (using Hutchinson's trace estimator) or exactly (by computing the full divergence of the vector field, which is relatively computationally expensive).

def approx_logp_wrapper(t, y, args):
    """Augmented vector field returning `func`'s output and a Hutchinson trace estimate.

    Args:
        t: current time.
        y: pair `(state, logp)`; only `state` is used.
        args: `(*extra, eps, func)`. `func` is called as `func(t, state, extra)`
            with `extra` the list of leading entries; `eps` is the probe vector.

    Returns:
        `(f, logp)` where `f = func(t, state, extra)` and `logp = eps^T (df/dy) eps`.
    """
    state, _ = y
    *extra, eps, func = args

    def vector_field(z):
        return func(t, z, extra)

    f, pullback = jax.vjp(vector_field, state)
    # A single VJP gives eps^T (df/dy); dotting with eps gives the estimator.
    (eps_jac,) = pullback(eps)
    return f, jnp.sum(eps * eps_jac)


def exact_logp_wrapper(t, y, args):
    """Augmented vector field returning `func`'s output and the exact divergence.

    Args:
        t: current time.
        y: pair `(state, logp)`; only `state` is used.
        args: `(*extra, eps, func)`. `eps` is ignored (it is accepted so this
            wrapper shares a calling convention with `approx_logp_wrapper`);
            `func` is called as `func(t, state, extra)` with `extra` the list of
            leading entries.

    Returns:
        `(f, logp)` where `f = func(t, state, extra)` and `logp` is the trace of
        the Jacobian `df/dstate`, computed exactly with one VJP per output element.
        States of any rank are supported.
    """
    y, _ = y
    *args, _, func = args

    def fn(z):
        return func(t, z, args)

    f, vjp_fn = jax.vjp(fn, y)
    # Pull back every standard basis vector of the output space. Each basis vector
    # is reshaped to `f.shape`, so states of any rank work (not only 1D ones).
    basis = jnp.eye(f.size, dtype=f.dtype).reshape((f.size, *f.shape))
    (dfdy,) = jax.vmap(vjp_fn)(basis)
    # Row i of the flattened result is e_i^T (df/dy), i.e. the full Jacobian.
    logp = jnp.trace(dfdy.reshape(f.size, y.size))
    return f, logp

Wrap up the differential equation solve into a model.

def normal_log_likelihood(y):
    """Log-density of `y` (any shape) under a standard normal N(0, I)."""
    log_norm_const = y.size * math.log(2 * math.pi)
    squared_norm = jnp.sum(jnp.square(y))
    return -0.5 * (log_norm_const + squared_norm)


class CNF(eqx.Module):
    """Continuous normalising flow made of `num_blocks` consecutive ODE solves.

    Each block integrates its own `Func` vector field over `[t0, t1]`. Sampling runs
    the blocks forward starting from standard-normal noise; training runs them
    backward starting from a data point, accumulating the change in log-density.
    """

    funcs: list[Func]  # one vector field per block, applied in order when sampling
    data_size: int
    exact_logp: bool  # exact divergence (True) or Hutchinson estimate (False)
    t0: float
    t1: float
    dt0: float  # step size passed to every solve (no step-size controller is given)

    def __init__(
        self,
        *,
        data_size,
        exact_logp,
        num_blocks,
        width_size,
        depth,
        key,
        **kwargs,
    ):
        super().__init__(**kwargs)
        keys = jr.split(key, num_blocks)
        self.funcs = [
            Func(
                data_size=data_size,
                width_size=width_size,
                depth=depth,
                key=k,
            )
            for k in keys
        ]
        self.data_size = data_size
        self.exact_logp = exact_logp
        self.t0 = 0.0
        self.t1 = 0.5
        self.dt0 = 0.05

    # Runs backward-in-time to train the CNF.
    def train(self, y, *, key):
        """Return the log-likelihood of a single (unbatched) data point `y`.

        The blocks are solved in reverse order, each from `t1` back to `t0`,
        carrying the accumulated change in log-density. The resulting state is
        scored under a standard normal and added to that accumulated change.

        `key` draws the probe vector `eps` (same shape as `y`) for Hutchinson's
        estimator; the exact wrapper receives it too but ignores it.
        """
        if self.exact_logp:
            term = diffrax.ODETerm(exact_logp_wrapper)
        else:
            term = diffrax.ODETerm(approx_logp_wrapper)
        solver = diffrax.Tsit5()
        eps = jr.normal(key, y.shape)
        delta_log_likelihood = 0.0
        for func in reversed(self.funcs):
            state = (y, delta_log_likelihood)
            sol = diffrax.diffeqsolve(
                term, solver, self.t1, self.t0, -self.dt0, state, (eps, func)
            )
            # Only the final state is saved, hence the leading length-1 axis.
            (y,), (delta_log_likelihood,) = sol.ys
        return delta_log_likelihood + normal_log_likelihood(y)

    # Runs forward-in-time to draw samples from the CNF.
    def sample(self, *, key):
        """Draw one sample by pushing standard-normal noise through every block."""
        solver = diffrax.Tsit5()
        y = jr.normal(key, (self.data_size,))
        for func in self.funcs:
            sol = diffrax.diffeqsolve(
                diffrax.ODETerm(func), solver, self.t0, self.t1, self.dt0, y
            )
            (y,) = sol.ys
        return y

    # To make illustrations, we have a variant sample method we can query to see the
    # evolution of the samples during the forward solve.
    def sample_flow(self, *, key):
        """Draw one sample and record its state at 6 evenly spaced times.

        The blocks are treated as running back-to-back, so overall time runs from
        `t0` to `t0 + num_blocks * (t1 - t0)`. Returns the states at the 6 save
        times, in order, shape `(6, data_size)`.
        """
        block_duration = self.t1 - self.t0
        t_end = self.t0 + block_duration * len(self.funcs)
        save_times = jnp.linspace(self.t0, t_end, 6)
        solver = diffrax.Tsit5()
        y = jr.normal(key, (self.data_size,))
        t_so_far = self.t0  # overall time at which the current block starts
        out = []
        for i, func in enumerate(self.funcs):
            is_last = i == len(self.funcs) - 1
            block_end = t_so_far + block_duration
            if is_last:
                mask = t_so_far <= save_times
            else:
                # A save time on a block boundary belongs to the block starting there.
                mask = (t_so_far <= save_times) & (save_times < block_end)
            # Convert overall time to this block's local time in [t0, t1].
            save_ts = save_times[mask] - t_so_far + self.t0
            if not is_last:
                # Also save at the end of the block, so that the next block starts
                # from the true final state rather than the last requested save.
                save_ts = jnp.append(save_ts, self.t1)
            saveat = diffrax.SaveAt(ts=save_ts)
            sol = diffrax.diffeqsolve(
                diffrax.ODETerm(func),
                solver,
                self.t0,
                self.t1,
                self.dt0,
                y,
                saveat=saveat,
            )
            if is_last:
                out.append(sol.ys)
            else:
                out.append(sol.ys[:-1])  # drop the extra end-of-block save
                y = sol.ys[-1]
            t_so_far = block_end
        out = jnp.concatenate(out)
        assert len(out) == 6  # number of points we saved at
        return out

Alright, that's the models done. Now let's get some data.

First we have a function for taking the specified input image, and turning it into data.

def get_data(path):
    """Turn the image at `path` into a weighted, normalised 2D point dataset.

    Every pixel becomes a point at its (x, y) pixel coordinate, with y increasing
    upwards. Darker pixels get larger weights (`1 - value / max_value`), and pixels
    with greyscale value >= 254 are dropped. Coordinates are standardised per axis.

    Greyscale, greyscale+alpha, RGB and RGBA images are accepted. Colour is
    converted to greyscale, and any alpha channel is ignored.

    Returns:
        dataset: normalised coordinates, shape (num_points, 2).
        weights: per-point weights, shape (num_points,).
        mean, std: per-axis statistics used to normalise `dataset`.
        img: the greyscale image, transposed and flipped to shape (width, height).
        width, height: image dimensions in pixels.

    Raises:
        ValueError: if the image array is not one of the supported layouts.
    """
    # Presumably integer values in {0, ..., 255}; shape (height, width) for
    # greyscale images or (height, width, channels) otherwise.
    img = jnp.asarray(imageio.imread(path))
    if img.ndim == 3:
        if img.shape[-1] in (2, 4):
            img = img[..., :-1]  # ignore alpha channel
        if img.shape[-1] == 3:
            # Convert to greyscale for simplicity.
            img = img @ jnp.array([0.2989, 0.5870, 0.1140])
        elif img.shape[-1] == 1:
            img = img[..., 0]
    if img.ndim != 2:
        raise ValueError(
            "Expected a greyscale, greyscale+alpha, RGB or RGBA image; got an "
            f"array of shape {img.shape}."
        )
    img = img.astype(jnp.float32)
    height, width = img.shape
    img = jnp.transpose(img)[:, ::-1]  # (width, height), flipped so y increases upwards
    x = jnp.arange(width, dtype=jnp.float32)
    y = jnp.arange(height, dtype=jnp.float32)
    x, y = jnp.broadcast_arrays(x[:, None], y[None, :])
    weights = 1 - img.reshape(-1) / jnp.max(img)
    dataset = jnp.stack(
        [x.reshape(-1), y.reshape(-1)], axis=-1
    )  # shape (dataset_size, 2)
    # For efficiency we don't bother with (near-)white pixels, whose weight is
    # (close to) zero.
    cond = img.reshape(-1) < 254
    dataset = dataset[cond]
    weights = weights[cond]
    mean = jnp.mean(dataset, axis=0)
    std = jnp.std(dataset, axis=0) + 1e-6  # avoid division by zero
    dataset = (dataset - mean) / std

    return dataset, weights, mean, std, img, width, height

Now to load the data during training, we need a dataloader. In this case our dataset is small enough to fit in-memory, so we use a dataloader implementation that we can include within our overall JIT wrapper, for speed.

class DataLoader(eqx.Module):
    """In-memory dataloader usable inside JIT-compiled code.

    Calling the loader with an integer `step` returns batch `step % num_batches` of
    epoch `step // num_batches`. Each epoch uses a fresh permutation derived from
    `key` and the epoch index, so the loader holds no mutable state and `step` may
    be a traced value. Examples left over after the last full batch of an epoch
    are not used in that epoch.
    """

    arrays: tuple[jnp.ndarray, ...]  # all share the same leading (dataset) dimension
    batch_size: int
    key: jax.Array  # PRNG key used to derive the per-epoch permutations

    def __check_init__(self):
        if not self.arrays:
            raise ValueError("DataLoader needs at least one array.")
        dataset_size = self.arrays[0].shape[0]
        assert all(array.shape[0] == dataset_size for array in self.arrays)
        # Without this check, `num_batches` in `__call__` would be 0 (an integer
        # division by zero inside traced code), and the batch slice would be
        # larger than the dataset.
        if not 0 < self.batch_size <= dataset_size:
            raise ValueError(
                f"batch_size must be in [1, {dataset_size}] (the dataset size), "
                f"got {self.batch_size}."
            )

    def __call__(self, step):
        dataset_size = self.arrays[0].shape[0]
        num_batches = dataset_size // self.batch_size
        epoch = step // num_batches
        key = jr.fold_in(self.key, epoch)
        perm = jr.permutation(key, jnp.arange(dataset_size))
        start = (step % num_batches) * self.batch_size
        # dynamic_slice (rather than Python slicing) because `start` may be traced.
        batch_indices = lax.dynamic_slice_in_dim(perm, start, self.batch_size)
        return tuple(array[batch_indices] for array in self.arrays)

Bring everything together. This function is our entry point.

def main(
    in_path,
    out_path=None,
    batch_size=500,
    virtual_batches=2,
    lr=1e-3,
    weight_decay=1e-5,
    steps=10000,
    exact_logp=True,
    num_blocks=2,
    width_size=64,
    depth=3,
    print_every=100,
    seed=5678,
):
    """Train a CNF on the image at `in_path`, then plot samples and save the figure.

    Args:
        in_path: path to the target image.
        out_path: where to save the figure. Defaults to a file named like the input
            image in the current working directory. A `_cnf` suffix is added if
            that would otherwise be the input image itself.
        batch_size: data points per dataloader batch.
        virtual_batches: batches whose gradients are averaged per optimiser step.
        lr, weight_decay: AdamW hyperparameters.
        steps: total number of dataloader batches to consume. Each optimiser step
            consumes `virtual_batches` of them.
        exact_logp: exact divergence (True) or Hutchinson's estimator (False).
        num_blocks, width_size, depth: CNF architecture.
        print_every: report the loss each time the batch counter crosses a
            multiple of this, and once more at the end of training.
        seed: PRNG seed.
    """
    if out_path is None:
        out_path = here / pathlib.Path(in_path).name
        # The default name matches the input's, so running from the input's own
        # directory would otherwise overwrite the input image with the figure.
        if out_path.resolve() == pathlib.Path(in_path).resolve():
            out_path = out_path.with_name(f"{out_path.stem}_cnf{out_path.suffix}")
    else:
        out_path = pathlib.Path(out_path)

    key = jr.PRNGKey(seed)
    model_key, loader_key, loss_key, sample_key = jr.split(key, 4)

    dataset, weights, mean, std, img, width, height = get_data(in_path)
    _, data_size = dataset.shape
    dataloader = DataLoader((dataset, weights), batch_size, key=loader_key)

    model = CNF(
        data_size=data_size,
        exact_logp=exact_logp,
        num_blocks=num_blocks,
        width_size=width_size,
        depth=depth,
        key=model_key,
    )

    optim = optax.adamw(lr, weight_decay=weight_decay)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_value_and_grad
    def loss(model, data, weight, loss_key):
        batch_size, _ = data.shape
        noise_key, train_key = jr.split(loss_key, 2)
        # One key per data point, derived from this step's key. Splitting the fixed
        # root `key` here would reuse the same Hutchinson probes at every step.
        train_keys = jr.split(train_key, batch_size)
        # Jitter each point with Gaussian noise of scale 0.5 pixels (expressed in
        # the normalised coordinates).
        data = data + jr.normal(noise_key, data.shape) * 0.5 / std
        log_likelihood = jax.vmap(model.train)(data, key=train_keys)
        return -jnp.mean(weight * log_likelihood)  # minimise negative log-likelihood

    @eqx.filter_jit
    def make_step(model, opt_state, step, loss_key):
        # We only need gradients with respect to floating point JAX arrays, not any
        # other part of our model. (e.g. the `exact_logp` flag. What would it even mean
        # to differentiate that? Note that `eqx.filter_value_and_grad` does the same
        # filtering by `eqx.is_inexact_array` by default.)
        value = 0
        grads = jax.tree_util.tree_map(
            lambda leaf: jnp.zeros_like(leaf) if eqx.is_inexact_array(leaf) else None,
            model,
        )

        # Get more accurate gradients by accumulating gradients over multiple batches.
        # (Or equivalently, get lower memory requirements by splitting up a batch over
        # multiple steps.)
        def make_virtual_step(_, state):
            value, grads, step, loss_key = state
            data, weight = dataloader(step)
            value_, grads_ = loss(model, data, weight, loss_key)
            value = value + value_
            grads = jax.tree_util.tree_map(lambda a, b: a + b, grads, grads_)
            step = step + 1
            loss_key = jr.split(loss_key, 1)[0]
            return value, grads, step, loss_key

        value, grads, step, loss_key = lax.fori_loop(
            0, virtual_batches, make_virtual_step, (value, grads, step, loss_key)
        )
        value = value / virtual_batches
        grads = jax.tree_util.tree_map(lambda a: a / virtual_batches, grads)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state, step, loss_key

    # `step` stays a JAX array after the first call, so it is traced rather than
    # static and `make_step` is not recompiled for every new value.
    step = 0
    while step < steps:
        prev_step = int(step)
        start = time.time()
        value, model, opt_state, step, loss_key = make_step(
            model, opt_state, step, loss_key
        )
        # int() waits for the jitted computation that produces `step`. Doing this
        # before stopping the timer means the timing includes the work, not just
        # JAX's asynchronous dispatch.
        step_int = int(step)
        end = time.time()
        # `step` advances by `virtual_batches` per iteration, so check whether a
        # multiple of `print_every` was crossed rather than hit exactly. Always
        # report the final step.
        if step_int // print_every > prev_step // print_every or step_int >= steps:
            print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")

    num_samples = 5000
    sample_key = jr.split(sample_key, num_samples)
    samples = jax.vmap(model.sample)(key=sample_key)
    # Same keys as `samples`; shape (num_saved_times, data_size, num_samples).
    sample_flows = jax.vmap(model.sample_flow, out_axes=-1)(key=sample_key)
    _, (*axs, ax, axtrue) = plt.subplots(
        1,
        2 + len(sample_flows),
        figsize=((2 + len(sample_flows)) * 10 * height / width, 10),
    )

    # Scatter plot of the final samples, mapped back to pixel coordinates.
    samples = samples * std + mean
    x = samples[:, 0]
    y = samples[:, 1]
    ax.scatter(x, y, c="black", s=2)
    ax.set_xlim(-0.5, width - 0.5)
    ax.set_ylim(-0.5, height - 0.5)
    ax.set_aspect(height / width)
    ax.set_xticks([])
    ax.set_yticks([])

    # The target image itself.
    axtrue.imshow(img.T, origin="lower", cmap="gray")
    axtrue.set_aspect(height / width)
    axtrue.set_xticks([])
    axtrue.set_yticks([])

    # Kernel density estimates of the samples at each saved time along the flow.
    x_resolution = 100
    y_resolution = int(x_resolution * (height / width))
    sample_flows = sample_flows * std[:, None] + mean[:, None]
    x_pos, y_pos = jnp.broadcast_arrays(
        jnp.linspace(-1, width + 1, x_resolution)[:, None],
        jnp.linspace(-1, height + 1, y_resolution)[None, :],
    )
    positions = jnp.stack([jnp.ravel(x_pos), jnp.ravel(y_pos)])
    densities = [
        stats.gaussian_kde(flow_samples)(positions) for flow_samples in sample_flows
    ]
    for flow_ax, density in zip(axs, densities):
        density = jnp.reshape(density, (x_resolution, y_resolution))
        flow_ax.imshow(density.T, origin="lower", cmap="plasma")
        flow_ax.set_aspect(height / width)
        flow_ax.set_xticks([])
        flow_ax.set_yticks([])
    plt.savefig(out_path)
    plt.show()

And now the following commands will reproduce the images displayed at the start.

# Reproduce the example outputs shown at the top of the page. Each call builds and
# trains a fresh CNF, then saves its figure.
main(in_path="../imgs/cat.png")
main(in_path="../imgs/butterfly.png", num_blocks=3)
main(in_path="../imgs/target.png", width_size=128)

Let's run the first one of those as a demonstration.

# Demonstration run with the default hyperparameters; its training log follows.
main(in_path="../imgs/cat.png")
Step: 100, Loss: 1.3118890523910522, Computation time: 1.0344874858856201
Step: 200, Loss: 1.251081943511963, Computation time: 1.0480339527130127
Step: 300, Loss: 1.2637590169906616, Computation time: 1.053438425064087
Step: 400, Loss: 1.238480567932129, Computation time: 1.1258141994476318
Step: 500, Loss: 1.2385149002075195, Computation time: 1.1150507926940918
Step: 600, Loss: 1.1769757270812988, Computation time: 1.0526695251464844
Step: 700, Loss: 1.2026259899139404, Computation time: 1.0425992012023926
Step: 800, Loss: 1.217708706855774, Computation time: 1.0489792823791504
Step: 900, Loss: 1.227858543395996, Computation time: 1.0484890937805176
Step: 1000, Loss: 1.189058780670166, Computation time: 1.0530979633331299
Step: 1100, Loss: 1.1783877611160278, Computation time: 1.0550987720489502
Step: 1200, Loss: 1.150791049003601, Computation time: 1.0446579456329346
Step: 1300, Loss: 1.221590518951416, Computation time: 1.0396397113800049
Step: 1400, Loss: 1.1535735130310059, Computation time: 1.0444588661193848
Step: 1500, Loss: 1.2072296142578125, Computation time: 1.0372049808502197
Step: 1600, Loss: 1.1677805185317993, Computation time: 1.0316574573516846
Step: 1700, Loss: 1.1480119228363037, Computation time: 1.037048101425171
Step: 1800, Loss: 1.1469142436981201, Computation time: 1.0355620384216309
Step: 1900, Loss: 1.1540181636810303, Computation time: 1.0425419807434082
Step: 2000, Loss: 1.170578122138977, Computation time: 1.034224271774292
Step: 2100, Loss: 1.1531469821929932, Computation time: 1.110149621963501
Step: 2200, Loss: 1.1332671642303467, Computation time: 1.1052398681640625
Step: 2300, Loss: 1.1843323707580566, Computation time: 1.0217394828796387
Step: 2400, Loss: 1.1606663465499878, Computation time: 1.0373609066009521
Step: 2500, Loss: 1.132057785987854, Computation time: 1.029083013534546
Step: 2600, Loss: 1.1429660320281982, Computation time: 1.0214695930480957
Step: 2700, Loss: 1.152261734008789, Computation time: 1.041821002960205
Step: 2800, Loss: 1.1637940406799316, Computation time: 1.023808240890503
Step: 2900, Loss: 1.1682878732681274, Computation time: 1.0261313915252686
Step: 3000, Loss: 1.1528184413909912, Computation time: 1.0209197998046875
Step: 3100, Loss: 1.1718814373016357, Computation time: 1.0247409343719482
Step: 3200, Loss: 1.1433460712432861, Computation time: 1.0423102378845215
Step: 3300, Loss: 1.166672706604004, Computation time: 1.025174856185913
Step: 3400, Loss: 1.1842900514602661, Computation time: 1.0576817989349365
Step: 3500, Loss: 1.1458779573440552, Computation time: 1.043668270111084
Step: 3600, Loss: 1.158961296081543, Computation time: 1.0318820476531982
Step: 3700, Loss: 1.157321810722351, Computation time: 1.042715072631836
Step: 3800, Loss: 1.1473262310028076, Computation time: 1.0452227592468262
Step: 3900, Loss: 1.1316838264465332, Computation time: 1.0491411685943604
Step: 4000, Loss: 1.149780035018921, Computation time: 1.0549867153167725
Step: 4100, Loss: 1.140995740890503, Computation time: 1.0537455081939697
Step: 4200, Loss: 1.1636855602264404, Computation time: 1.0538051128387451
Step: 4300, Loss: 1.1459300518035889, Computation time: 1.050825834274292
Step: 4400, Loss: 1.1132702827453613, Computation time: 1.0565142631530762
Step: 4500, Loss: 1.1445338726043701, Computation time: 1.0447840690612793
Step: 4600, Loss: 1.185341477394104, Computation time: 1.0458967685699463
Step: 4700, Loss: 1.1719266176223755, Computation time: 1.0501148700714111
Step: 4800, Loss: 1.1642041206359863, Computation time: 1.0497636795043945
Step: 4900, Loss: 1.1277761459350586, Computation time: 1.0412952899932861
Step: 5000, Loss: 1.1528496742248535, Computation time: 1.0490319728851318
Step: 5100, Loss: 1.1376690864562988, Computation time: 1.0476291179656982
Step: 5200, Loss: 1.1241021156311035, Computation time: 1.0485484600067139
Step: 5300, Loss: 1.186220645904541, Computation time: 1.054675817489624
Step: 5400, Loss: 1.1623786687850952, Computation time: 1.0538649559020996
Step: 5500, Loss: 1.156112551689148, Computation time: 1.0428664684295654
Step: 5600, Loss: 1.1495476961135864, Computation time: 1.0700688362121582
Step: 5700, Loss: 1.1711654663085938, Computation time: 1.056363821029663
Step: 5800, Loss: 1.15413236618042, Computation time: 1.0482234954833984
Step: 5900, Loss: 1.13923180103302, Computation time: 1.03656005859375
Step: 6000, Loss: 1.1366792917251587, Computation time: 1.0534214973449707
Step: 6100, Loss: 1.1102700233459473, Computation time: 1.0474035739898682
Step: 6200, Loss: 1.1211810111999512, Computation time: 1.053070068359375
Step: 6300, Loss: 1.1677824258804321, Computation time: 1.0389189720153809
Step: 6400, Loss: 1.152880072593689, Computation time: 1.0483548641204834
Step: 6500, Loss: 1.1546416282653809, Computation time: 1.073331594467163
Step: 6600, Loss: 1.1045620441436768, Computation time: 1.0458285808563232
Step: 6700, Loss: 1.1473908424377441, Computation time: 1.0490500926971436
Step: 6800, Loss: 1.1420358419418335, Computation time: 1.0393445491790771
Step: 6900, Loss: 1.1263903379440308, Computation time: 1.0471012592315674
Step: 7000, Loss: 1.1546120643615723, Computation time: 1.120964765548706
Step: 7100, Loss: 1.1234941482543945, Computation time: 1.1098055839538574
Step: 7200, Loss: 1.1735379695892334, Computation time: 1.112914800643921
Step: 7300, Loss: 1.1549310684204102, Computation time: 1.1177945137023926
Step: 7400, Loss: 1.1674282550811768, Computation time: 1.0430011749267578
Step: 7500, Loss: 1.1249209642410278, Computation time: 1.0475146770477295
Step: 7600, Loss: 1.149993896484375, Computation time: 1.0517895221710205
Step: 7700, Loss: 1.126546859741211, Computation time: 1.0522327423095703
Step: 7800, Loss: 1.0991778373718262, Computation time: 1.034656286239624
Step: 7900, Loss: 1.1408506631851196, Computation time: 1.0303537845611572
Step: 8000, Loss: 1.1567769050598145, Computation time: 1.0465099811553955
Step: 8100, Loss: 1.1207897663116455, Computation time: 1.0353443622589111
Step: 8200, Loss: 1.1423345804214478, Computation time: 1.0374336242675781
Step: 8300, Loss: 1.1438696384429932, Computation time: 1.0480520725250244
Step: 8400, Loss: 1.184295654296875, Computation time: 1.046989917755127
Step: 8500, Loss: 1.1350500583648682, Computation time: 1.0438358783721924
Step: 8600, Loss: 1.1440303325653076, Computation time: 1.0357015132904053
Step: 8700, Loss: 1.1465599536895752, Computation time: 1.0545375347137451
Step: 8800, Loss: 1.116767406463623, Computation time: 1.0341439247131348
Step: 8900, Loss: 1.1343653202056885, Computation time: 1.0588762760162354
Step: 9000, Loss: 1.1371710300445557, Computation time: 1.0479660034179688
Step: 9100, Loss: 1.125238299369812, Computation time: 1.0390236377716064
Step: 9200, Loss: 1.1323764324188232, Computation time: 1.0565376281738281
Step: 9300, Loss: 1.1228868961334229, Computation time: 1.0413126945495605
Step: 9400, Loss: 1.1537754535675049, Computation time: 1.0376060009002686
Step: 9500, Loss: 1.157130241394043, Computation time: 1.0577881336212158
Step: 9600, Loss: 1.1420390605926514, Computation time: 1.058633804321289
Step: 9700, Loss: 1.14859139919281, Computation time: 1.0608363151550293
Step: 9800, Loss: 1.0945243835449219, Computation time: 1.0353436470031738
Step: 9900, Loss: 1.1538352966308594, Computation time: 1.0435452461242676
Step: 10000, Loss: 1.1333094835281372, Computation time: 1.0515074729919434