Skip to content

Freezing parameters

In this example, we demonstrate how to train only some parameters and freeze the rest.

This example is available as a Jupyter notebook here.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
import optax  #
# Toy data
def get_data(dataset_size, *, key):
    """Sample a toy regression dataset lying on the line ``y = 5x - 2``.

    Args:
        dataset_size: number of samples to draw.
        key: JAX PRNG key used to draw the inputs.

    Returns:
        A pair ``(x, y)`` of arrays, each of shape ``(dataset_size, 1)``, where
        ``x`` is standard-normal and ``y`` is the exact (noise-free) target.
    """
    inputs = jrandom.normal(key, (dataset_size, 1))
    return inputs, 5 * inputs - 2

# Toy dataloader
def dataloader(arrays, batch_size, *, key):
    """Yield shuffled mini-batches from ``arrays`` indefinitely.

    Each epoch draws a fresh permutation of the sample indices and yields
    consecutive full batches from it; a trailing partial batch (fewer than
    ``batch_size`` samples) is dropped.

    Args:
        arrays: sequence of arrays sharing the same leading (sample) dimension.
        batch_size: number of samples per batch; must satisfy
            ``1 <= batch_size <= dataset_size``.
        key: JAX PRNG key driving the shuffling.

    Yields:
        A tuple with one batch slice per input array, all indexed by the same
        permuted indices.

    Raises:
        ValueError: if ``batch_size`` is outside ``[1, dataset_size]`` (raised
            on the first ``next()`` since this is a generator).
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    if not 0 < batch_size <= dataset_size:
        # Otherwise no full batch fits and the loop below would spin forever
        # without ever yielding.
        raise ValueError(
            f"batch_size must be in [1, {dataset_size}], got {batch_size}"
        )
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        # `<=` so the last full batch (ending exactly at dataset_size) is kept.
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

Here, we:

  1. Set up a model. In this case, an MLP.
  2. Set up a `filter_spec`. This is a PyTree with the same structure as the model, with `False` on every leaf -- except for the leaves corresponding to the final layer's weight and bias, which we set to `True`.
  3. Specify how to make a step. We'll separate out the leaves we want to differentiate from the leaves that we want to leave alone by using equinox.partition.
def main(
    dataset_size=10000,
    batch_size=256,
    learning_rate=3e-3,
    steps=1000,
    width_size=8,
    depth=1,
    seed=5678,
):
    """Train only the final layer of an MLP on toy linear data, then print parameters.

    Every parameter except the last layer's weight and bias is frozen: a
    boolean ``filter_spec`` partitions the model so that only those two arrays
    are differentiated, and therefore only they receive SGD updates.

    Args:
        dataset_size: number of toy ``(x, y)`` samples to generate.
        batch_size: samples per training batch.
        learning_rate: SGD learning rate.
        steps: number of optimisation steps.
        width_size: hidden width of the MLP.
        depth: number of hidden layers of the MLP.
        seed: seed for the PRNG key that drives data, shuffling and init.
    """
    data_key, loader_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 3)
    data = get_data(dataset_size, key=data_key)
    data_iter = dataloader(data, batch_size, key=loader_key)

    # Step 1: the model.
    model = eqx.nn.MLP(
        in_size=1, out_size=1, width_size=width_size, depth=depth, key=model_key
    )

    # Step 2: a PyTree mirroring `model`, True only on the leaves to train.
    filter_spec = jtu.tree_map(lambda _: False, model)
    filter_spec = eqx.tree_at(
        lambda tree: (tree.layers[-1].weight, tree.layers[-1].bias),
        filter_spec,
        replace=(True, True),
    )

    # Defined before `make_step` so the closure does not rely on late binding.
    optim = optax.sgd(learning_rate)

    # Step 3: one optimisation step.
    @eqx.filter_jit
    def make_step(model, x, y, opt_state):
        # filter_grad differentiates w.r.t. the first argument only, so only the
        # trainable leaves (those kept by `partition`) receive gradients.
        @eqx.filter_grad
        def loss(diff_model, static_model, x, y):
            model = eqx.combine(diff_model, static_model)
            pred_y = jax.vmap(model)(x)
            return jnp.mean((y - pred_y) ** 2)

        diff_model, static_model = eqx.partition(model, filter_spec)
        grads = loss(diff_model, static_model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state

    # And now let's train for a short while -- in exactly the usual way -- and see what
    # happens. We keep the original model around to compare to later.
    original_model = model
    opt_state = optim.init(model)
    for _, (x, y) in zip(range(steps), data_iter):
        model, opt_state = make_step(model, x, y, opt_state)

    snapshots = (("at initialisation", original_model), ("at end of training", model))
    for which, index in (("first", 0), ("last", -1)):
        for when, snapshot in snapshots:
            print(
                f"Parameters of {which} layer {when}:\n"
                f"{jtu.tree_leaves(snapshot.layers[index])}\n"
            )

As we'll see, the parameters of the first layer remain unchanged throughout training. Only the parameters of the last layer are trained.

Parameters of first layer at initialisation:
[DeviceArray([[-0.5500405 ],
             [ 0.67074966],
             [-0.9094155 ],
             [-0.5518596 ],
             [-0.1648488 ],
             [ 0.98241615],
             [-0.9118581 ],
             [ 0.32483125]], dtype=float32), DeviceArray([ 0.8876705 ,  0.4363706 , -0.878813  ,  0.26387787,
             -0.68248963, -0.9517925 , -0.21384668, -0.2857628 ],            dtype=float32)]

Parameters of first layer at end of training:
[DeviceArray([[-0.5500405 ],
             [ 0.67074966],
             [-0.9094155 ],
             [-0.5518596 ],
             [-0.1648488 ],
             [ 0.98241615],
             [-0.9118581 ],
             [ 0.32483125]], dtype=float32), DeviceArray([ 0.8876705 ,  0.4363706 , -0.878813  ,  0.26387787,
             -0.68248963, -0.9517925 , -0.21384668, -0.2857628 ],            dtype=float32)]

Parameters of last layer at initialisation:
[DeviceArray([[ 0.33031464,  0.16732198,  0.04151077, -0.01495699,
              -0.00766617,  0.08186949,  0.33581698,  0.13139524]],            dtype=float32), DeviceArray([-0.05869024], dtype=float32)]

Parameters of last layer at end of training:
[DeviceArray([[-2.8348155 ,  3.3568215 , -0.767634  , -2.131908  ,
              -0.00766617,  1.3563743 , -1.6294041 ,  0.60640407]],            dtype=float32), DeviceArray([-0.10058348], dtype=float32)]