
Autoparallelism (running on multiple GPUs)

It's very common to have a machine with multiple GPUs, and to want to parallelise your computation over them.

JAX has a number of advanced APIs to support this. The main technique is to "shard" an array, so that each device holds part of the array.

In this example, we'll parallelise our computation (typically a training step) over 8 devices, so that each device gets 1/8 of each batch of data.
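
To make the "1/8 of the batch" idea concrete before any JAX is involved, here is a minimal NumPy-only sketch of what sharding a batch along its leading axis means. The device count of 8 is hard-coded as an assumption to match the text, and `shards` and `rows_per_device` are illustrative names; real sharding is done by JAX, not by `np.split`.

```python
import numpy as np

num_devices = 8  # assumed device count, matching the text
batch_size = 16
channel_size = 4

batch = np.arange(batch_size * channel_size, dtype=np.float32)
batch = batch.reshape(batch_size, channel_size)

# The batch must divide evenly between the devices.
assert batch_size % num_devices == 0
rows_per_device = batch_size // num_devices

# Split along the leading (batch) axis: one shard per device.
shards = np.split(batch, num_devices, axis=0)

for device_index, shard in enumerate(shards):
    first_row = device_index * rows_per_device
    last_row = first_row + rows_per_device - 1
    print(f"device {device_index}: rows {first_row}-{last_row}, shape {shard.shape}")
```

Each hypothetical device ends up holding `rows_per_device` rows of the batch, and concatenating the shards recovers the original array.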

First let's import everything, and set up our toy problem.

import equinox as eqx
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.random as jr
import jax.sharding as jshard
import numpy as np
import optax  # https://github.com/deepmind/optax


# Hyperparameters
dataset_size = 64
channel_size = 4
hidden_size = 32
depth = 1
learning_rate = 3e-4
num_steps = 10
batch_size = 16  # must be a multiple of our number of devices.

# Generate some synthetic data
xs = np.random.normal(size=(dataset_size, channel_size))
ys = np.sin(xs)

model = eqx.nn.MLP(channel_size, channel_size, hidden_size, depth, key=jr.PRNGKey(6789))
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))


# Loss function for a batch of data
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)


# Simple dataloader; randomly slices our dataset and shuffles between epochs.
# In NumPy for speed, as our dataset is small enough to fit entirely in host memory.
#
# For larger datasets (that require loading from disk), use PyTorch's `DataLoader`
# or TensorFlow's `tf.data` instead.
def train_dataloader(arrays, batch_size):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = np.arange(dataset_size)
    while True:
        perm = np.random.permutation(indices)
        start = 0
        end = batch_size
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

Okay, now the interesting things start happening!

First, we're going to arrange to "donate" memory, which specifies that the memory for our input arrays (e.g. model parameters) can be re-used to store the output arrays (e.g. updated model parameters). This isn't technically related to autoparallelism, but it's good practice, so you should do it anyway :)

Second, we're going to use eqx.filter_shard to assert (on the inputs) and enforce (on the outputs) how each array is split across our devices. As we're doing data parallelism in this example, we'll replicate our model parameters onto every device, whilst sharding our data between devices.

@eqx.filter_jit(donate="all")
def train_step(model, opt_state, x, y, sharding):
    replicated = sharding.replicate()
    model, opt_state = eqx.filter_shard((model, opt_state), replicated)
    x, y = eqx.filter_shard((x, y), sharding)

    grads = eqx.filter_grad(compute_loss)(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    model, opt_state = eqx.filter_shard((model, opt_state), replicated)

    return model, opt_state
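
Donation can also be seen in isolation using plain `jax.jit` and its `donate_argnums` argument. The following is a minimal sketch, not the training step above: `sgd_step` and `donating_step` are hypothetical names, and whether the input buffer is actually invalidated depends on the backend and JAX version, so the final line may print either `True` or `False`.

```python
import jax
import jax.numpy as jnp


def sgd_step(params, grads):
    # A toy update rule, standing in for a real optimiser step.
    return params - 0.1 * grads


# donate_argnums=0 tells JAX that the buffer backing `params` may be
# re-used to store the output.
donating_step = jax.jit(sgd_step, donate_argnums=0)

params = jnp.ones((4,))
grads = jnp.full((4,), 2.0)

new_params = donating_step(params, grads)
print(new_params)

# Where donation is implemented, the old buffer is now invalid and must
# not be read again; `is_deleted()` reports whether that happened.
print(params.is_deleted())
```

This is why the training loop always rebinds its results (`model, opt_state = train_step(...)`): after a donating call, only the returned arrays are safe to use.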

Now the magic: create our sharding object, move our data on to our devices, and run the code.

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()

model = eqx.filter_shard(model, replicated)
for step, (x, y) in zip(
    range(1, num_steps + 1), train_dataloader((xs, ys), batch_size)
):
    x, y = eqx.filter_shard((x, y), sharding)
    model, opt_state = train_step(model, opt_state, x, y, sharding)

Not strictly related to parallelism, but a common question at this point: if we want to evaluate our model, then we probably don't want to donate its parameters (which would render the model unusable, as all its memory is freed). As such, inference looks like this:

def eval_dataloader(arrays, batch_size):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    start = 0
    end = batch_size
    while start < dataset_size:
        yield tuple(array[start:end] for array in arrays)
        start = end
        end = start + batch_size


@eqx.filter_jit(donate="all-except-first")
def evaluate(model, x, y, sharding):
    replicated = sharding.replicate()
    model = eqx.filter_shard(model, replicated)
    x, y = eqx.filter_shard((x, y), sharding)
    return compute_loss(model, x, y)


loss = 0
num_batches = 0
for x, y in eval_dataloader((xs, ys), batch_size):
    loss = loss + evaluate(model, x, y, sharding).item()
    num_batches = num_batches + 1
print(f"train loss={loss/num_batches}")

That's it!

Once you've specified how you want to split up your input data, JAX does the rest for you! It takes your single JIT'd computation (which you wrote as if you were targeting a single huge device), and automatically determines how to split up that computation so that each device handles part of it. This is JAX's computation-follows-data approach to autoparallelism.
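
As a small sketch of computation-follows-data, the function below is written as if for a single device and only its input is sharded. Note the assumptions: it uses `jax.sharding.NamedSharding` with a one-dimensional `Mesh` rather than the `PositionalSharding` used in the example above, and `normalise`, the `"batch"` axis name, and the array sizes are all illustrative. It also runs on a single device, in which case the "sharding" is trivial.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, ("batch",))
# Shard the leading axis across the "batch" mesh axis; other axes are replicated.
batch_sharding = NamedSharding(mesh, PartitionSpec("batch"))


@jax.jit
def normalise(x):
    # Written with no reference to devices. The mean and standard deviation
    # are over the whole array, so JAX must combine results across devices.
    return (x - jnp.mean(x)) / jnp.std(x)


# Two rows per device, so the leading axis divides evenly.
x_host = np.arange(len(devices) * 2 * 4, dtype=np.float32).reshape(-1, 4)
x_sharded = jax.device_put(x_host, batch_sharding)

y = normalise(x_sharded)

# The result matches the same computation done on the host with NumPy.
expected = (x_host - x_host.mean()) / x_host.std()
print(np.allclose(np.asarray(y), expected, atol=1e-4))
```

The only device-aware line is the `jax.device_put` call; the function body is unchanged from what you would write for one device.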

If you run the above example on a computer with multiple NVIDIA GPUs, you can check whether you're using as many GPUs as you expected by running nvidia-smi from the command line. You can also use jax.debug.visualize_array_sharding(array) to inspect the sharding manually.
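
Alongside these tools, an array's placement can be inspected programmatically through its `sharding` and `addressable_shards` attributes. This is a hedged sketch on a toy array, again assuming a `NamedSharding` over a one-dimensional mesh rather than the sharding object from the example; `rows_per_device` is an illustrative name.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, ("batch",))
rows_per_device = 2

x = jax.device_put(
    np.zeros((len(devices) * rows_per_device, 4), dtype=np.float32),
    NamedSharding(mesh, PartitionSpec("batch")),
)

# The sharding object describes how the global array is laid out.
print(x.sharding)

# Each addressable shard records its device, the slice of the global array
# that it holds, and the local data itself.
for shard in x.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

With several devices, each line of the loop should show a different device holding a different slice of rows.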

What about pmap?

The JAX team have been hard at work introducing these new easy-to-use parallelism features, based around JIT and sharding. These are often faster and more expressive than pmap, so pmap is no longer recommended!

Types of parallelism

There are multiple types of parallelism. In this example we demonstrated data parallelism, in which we parallelise over the data. This is one of the simplest to set up, and often very effective.

For completeness we note that there are other kinds of parallelism available -- e.g. model parallelism, which instead places different parts of the model on different devices. A discussion on those is a more advanced topic.

{jax.device_put, jax.lax.with_sharding_constraint} vs eqx.filter_shard

This is the usual story in Equinox: eqx.filter_shard is a filtered version of these operations that leaves any non-arrays alone. In this case, it is needed because we have an activation function (i.e. just some arbitrary Python function, which isn't an array) as part of the MLP.
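
The idea of a "filtered" operation can be sketched in plain JAX. The helper below is hypothetical and is not how Equinox implements `eqx.filter_shard`; it only shows the principle of applying `jax.device_put` to array leaves while passing everything else (here, an activation function) through untouched. The dictionary `layer` is a stand-in for a model.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def device_put_arrays_only(tree, sharding):
    """Hypothetical helper: place array leaves, leave other leaves alone."""

    def place(leaf):
        if isinstance(leaf, (jax.Array, np.ndarray)):
            return jax.device_put(leaf, sharding)
        return leaf  # e.g. a Python function: not an array, so skip it

    return jax.tree_util.tree_map(place, tree)


def activation(x):
    return jnp.maximum(x, 0)


mesh = Mesh(np.array(jax.devices()), ("batch",))
# An empty PartitionSpec replicates the array on every device in the mesh.
replicated = NamedSharding(mesh, PartitionSpec())

# A toy "layer" mixing an array with an arbitrary Python function.
layer = {"weight": jnp.ones((4, 4)), "activation": activation}
placed = device_put_arrays_only(layer, replicated)

print(type(placed["weight"]).__name__, placed["weight"].sharding)
print(placed["activation"] is activation)
```

The function leaf comes back as the very same object, while the weight array is placed according to the sharding.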

Further reading

Equinox works smoothly with all the built-in parallelism APIs provided by JAX. If you want to know more, then the relevant parts of the JAX documentation are: