Autoparallelism (running on multiple GPUs)

It's very common to have a machine with multiple GPUs, and to seek to parallelise your computation over them.

JAX has a number of advanced APIs to support this. The main technique is to "shard" an array, so that each device holds part of the array.

In this example, we'll parallelise our computation (usually it's a training step) over 8 devices, so that each device gets 1/8 of the batch of data.

First let's import everything, and set up our toy problem.

import equinox as eqx
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.random as jr
import jax.sharding as sharding
import numpy as np
import optax

# Hyperparameters
dataset_size = 64
channel_size = 4
hidden_size = 32
depth = 1
learning_rate = 3e-4
num_steps = 10
batch_size = 16  # must be a multiple of our number of devices.

# Generate some synthetic data
xs = np.random.normal(size=(dataset_size, channel_size))
ys = np.sin(xs)

model = eqx.nn.MLP(channel_size, channel_size, hidden_size, depth, key=jr.PRNGKey(6789))
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

@eqx.filter_jit
def make_step(model, opt_state, x, y):
    grads = eqx.filter_grad(compute_loss)(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state

Here's a very simple dataloader that randomly shuffles and slices our dataset. We keep everything in pure NumPy for speed, as this all happens on the host before the data is moved to our devices (which will often be a cluster of GPUs).

In practice it's also common to load data using either PyTorch's DataLoader or TensorFlow's API; see here for more details.

def dataloader(arrays, batch_size):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = np.arange(dataset_size)
    while True:
        perm = np.random.permutation(indices)
        start = 0
        end = batch_size
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
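
As a quick illustration of how this generator behaves, here is a small sketch that repeats the dataloader (so that the snippet runs on its own) and draws a few batches from it. The array sizes mirror the hyperparameters above; `itertools.islice` and the names `batches`, `batch_shapes` and `pairs_aligned` are just illustrative choices. Note that the `while end <= dataset_size` loop drops any trailing partial batch, so every batch has exactly `batch_size` rows.

```python
import itertools

import numpy as np


def dataloader(arrays, batch_size):
    # Same generator as above, repeated so this snippet is self-contained.
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = np.arange(dataset_size)
    while True:
        perm = np.random.permutation(indices)
        start = 0
        end = batch_size
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


xs = np.random.normal(size=(64, 4))
ys = np.sin(xs)

# The generator is infinite, so take a fixed number of batches from it.
batches = list(itertools.islice(dataloader((xs, ys), batch_size=16), 5))
batch_shapes = [(x.shape, y.shape) for x, y in batches]

# x and y are indexed with the same permutation, so the pairs stay aligned.
pairs_aligned = all(np.allclose(np.sin(x), y) for x, y in batches)
```

Because the generator never terminates, the training loop needs its own stopping condition, which is why the loop further down zips it with `range(num_steps)`.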

Now the magic: create our sharding object, move our data on to our devices, and run the code.

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
shard = sharding.PositionalSharding(devices)

for step, (x, y) in zip(range(num_steps), dataloader((xs, ys), batch_size)):
    x, y = jax.device_put((x, y), shard)
    model, opt_state = make_step(model, opt_state, x, y)

That's it!

Once you've specified how you want to split up your input data, JAX does the rest for you! It takes your single JIT'd computation (which you wrote as if you were targeting a single huge device), automatically determines how to split up that computation, and has each device handle part of it. This is JAX's computation-follows-data approach to autoparallelism.
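
To make "computation follows data" a little more concrete, here is a minimal sketch under a few assumptions. It uses `jax.sharding.NamedSharding` over a one-dimensional `Mesh` rather than the `PositionalSharding` used above (our understanding is that newer JAX releases steer users towards `NamedSharding`, but check this against your installed version). The axis name `"batch"` and the function `elementwise` are made up for illustration. The function contains no parallelism-specific code; only the input placement differs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: a 1D mesh over every visible device; "batch" is a made-up axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("batch",))
data_sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Keep the leading (batch) dimension divisible by the number of devices.
batch_size = 16 * len(devices)
x = jax.device_put(np.ones((batch_size, 4), dtype=np.float32), data_sharding)


@jax.jit
def elementwise(x):
    # Written as if for one device; JAX decides how to split the work.
    return jnp.sin(x) * 2.0


y = elementwise(x)

# The result lives on the same set of devices that held the input.
input_device_count = len(x.sharding.device_set)
output_device_count = len(y.sharding.device_set)
```

On a single-device machine this runs unchanged (the mesh simply has one entry), which is a convenient way to test the code before moving to a multi-GPU machine.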

If you run the above example on a cluster of NVIDIA GPUs, you can check whether you're using as many GPUs as you expected by running nvidia-smi from the command line. You can also use jax.debug.visualize_array_sharding(array) to inspect the sharding manually.
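
The sharding can also be checked programmatically. The sketch below is one way to do that, assuming a `NamedSharding` over a one-dimensional mesh with a made-up axis name `"batch"`; the array sizes are arbitrary. It places an array across all visible devices and then lists which device holds which slice via the array's `addressable_shards` attribute.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("batch",))  # hypothetical axis name

# 16 rows per device, split along the leading (batch) dimension.
global_shape = (16 * num_devices, 4)
x = jax.device_put(
    np.zeros(global_shape, dtype=np.float32),
    NamedSharding(mesh, PartitionSpec("batch")),
)

# Each shard records the device holding it, the slice of the global array
# it covers, and the local data itself.
for shard in x.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)

shard_shapes = [shard.data.shape for shard in x.addressable_shards]
```

If every entry of `shard_shapes` is the full global shape rather than a slice of it, the array has been replicated instead of split, which usually means the sharding was not applied where you expected.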

One possible optimisation here is to re-use the memory of the input arrays to store the output arrays, which often improves speed a little. This is disabled by default, but can be enabled by wrapping the step function with eqx.filter_jit(donate="all").

What about pmap?

The JAX team have been hard at work introducing these new easy-to-use parallelism features, based around JIT and sharding. These are often faster and more expressive than pmap, so pmap is no longer recommended!

Types of parallelism

There are multiple types of parallelism. In this example we demonstrated data parallelism, in which we parallelise over the data. This is one of the simplest to set up, and often very effective.

For completeness, we note that there are other kinds of parallelism available, e.g. model parallelism, which instead places different parts of the model on different devices. A discussion of those is a more advanced topic. :)

Further reading

Equinox works smoothly with all the built-in parallelism APIs provided by JAX. If you want to know more, then the relevant parts of the JAX documentation are: