Skip to content

Freezing parameters

In this example, we demonstrate how the filtering of equinox.filter_value_and_grad can be customised -- in this case, to only train some parameters and freeze the rest.

import functools as ft

import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
import optax  # https://github.com/deepmind/optax

import equinox as eqx
# Toy data
def get_data(dataset_size, *, key):
    """Build a toy regression dataset following ``y = 5 * x - 2``.

    Args:
        dataset_size: number of samples to draw.
        key: JAX PRNG key used to sample the inputs.

    Returns:
        A pair ``(x, y)`` of arrays, each of shape ``(dataset_size, 1)``, where
        ``x`` is standard-normal noise and ``y`` is the affine target.
    """
    inputs = jrandom.normal(key, shape=(dataset_size, 1))
    targets = 5 * inputs - 2
    return inputs, targets


# Toy dataloader
def dataloader(arrays, batch_size, *, key):
    """Yield shuffled minibatches from ``arrays`` forever.

    Every pass over the data draws a fresh permutation of the row indices and
    yields consecutive slices of ``batch_size`` rows. A trailing partial batch
    (fewer than ``batch_size`` rows) is dropped.

    Args:
        arrays: sequence of arrays that all share the same leading dimension.
        batch_size: rows per batch; must satisfy ``1 <= batch_size <= dataset_size``.
        key: JAX PRNG key driving the shuffling.

    Yields:
        A tuple holding one batch per input array, all indexed by the same rows.

    Raises:
        ValueError: if ``batch_size`` is outside ``[1, dataset_size]``. Because
            this is a generator, the error surfaces on the first ``next()``.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    # Without this check an oversized batch would make every pass yield
    # nothing, so the outer loop would spin forever without producing a batch.
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be between 1 and the dataset size "
            f"({dataset_size}), got {batch_size}"
        )
    indices = jnp.arange(dataset_size)
    # Last valid start offset is `dataset_size - batch_size` (inclusive), so a
    # final batch that ends exactly at the end of the data is still yielded.
    stop = dataset_size - batch_size + 1
    while True:
        perm = jrandom.permutation(key, indices)
        # Derive the key for the next pass from the current one.
        (key,) = jrandom.split(key, 1)
        for start in range(0, stop, batch_size):
            batch_perm = perm[start : start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)

Here, we:

  1. Set up a model. In this case, an MLP.
  2. Set up a filter_spec. This will be a PyTree of the same structure as the model, with False on every leaf -- except for the leaves corresponding to the final layer, which we set to True.
  3. Specify how to make a step. In this case we'll specify that we're still going to JIT with respect to every array, but we're only going to differentiate the ones specified by filter_spec.
def main(
    dataset_size=10000,
    batch_size=256,
    learning_rate=3e-3,
    steps=1000,
    width_size=8,
    depth=1,
    seed=5678,
):
    """Train only the final layer of an MLP and print first/last layer parameters.

    Args:
        dataset_size: number of samples requested from ``get_data``.
        batch_size: batch size passed to ``dataloader``.
        learning_rate: learning rate for ``optax.sgd``.
        steps: number of optimisation steps to run.
        width_size: hidden width of the ``eqx.nn.MLP``.
        depth: depth argument of the ``eqx.nn.MLP``.
        seed: integer seed for ``jrandom.PRNGKey``.
    """
    data_key, loader_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 3)
    data = get_data(dataset_size, key=data_key)
    data_iter = dataloader(data, batch_size, key=loader_key)

    # Step 1
    model = eqx.nn.MLP(
        in_size=1, out_size=1, width_size=width_size, depth=depth, key=model_key
    )

    # Step 2
    # Start from "freeze everything", then mark the final layer as trainable.
    filter_spec = jtu.tree_map(lambda _: False, model)
    filter_spec = eqx.tree_at(
        lambda tree: (tree.layers[-1].weight, tree.layers[-1].bias),
        filter_spec,
        replace=(True, True),
    )

    # Step 3
    # The model is split into a trainable part and a frozen part, and only the
    # former is passed as the differentiated (first) argument. This replaces
    # the `arg=filter_spec` keyword of `eqx.filter_value_and_grad`, which
    # appears to be unavailable in newer Equinox versions (see Equinox docs).
    @eqx.filter_value_and_grad
    def compute_loss(diff_model, static_model, x, y):
        combined = eqx.combine(diff_model, static_model)
        pred_y = jax.vmap(combined)(x)
        return jnp.mean((y - pred_y) ** 2)

    @eqx.filter_jit
    def make_step(model, x, y):
        diff_model, static_model = eqx.partition(model, filter_spec)
        return compute_loss(diff_model, static_model, x, y)

    # And now let's train for a short while -- in exactly the usual way -- and see what
    # happens. We keep the original model around to compare to later.
    original_model = model
    optim = optax.sgd(learning_rate)
    opt_state = optim.init(model)
    # `range(steps)` goes first in the zip so no extra batch is drawn from the
    # (infinite) loader once the step budget is exhausted.
    for _, (x, y) in zip(range(steps), data_iter):
        _loss, grads = make_step(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
    print(
        f"Parameters of first layer at initialisation:\n{jtu.tree_leaves(original_model.layers[0])}\n"
    )
    print(
        f"Parameters of first layer at end of training:\n{jtu.tree_leaves(model.layers[0])}\n"
    )
    print(
        f"Parameters of last layer at initialisation:\n{jtu.tree_leaves(original_model.layers[-1])}\n"
    )
    print(
        f"Parameters of last layer at end of training:\n{jtu.tree_leaves(model.layers[-1])}\n"
    )

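As we'll see, the parameters of the first layer remain unchanged throughout training. Only the parameters of the last layer are trained. (An individual last-layer weight can still end up unchanged if its gradient is zero on every batch -- for example when the hidden unit feeding it is never active -- as happens for one weight in the output below.)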

# Run the example with its default hyperparameters.
main()
Parameters of first layer at initialisation:
[DeviceArray([[-0.5500405 ],
             [ 0.67074966],
             [-0.9094155 ],
             [-0.5518596 ],
             [-0.1648488 ],
             [ 0.98241615],
             [-0.9118581 ],
             [ 0.32483125]], dtype=float32), DeviceArray([ 0.8876705 ,  0.4363706 , -0.878813  ,  0.26387787,
             -0.68248963, -0.9517925 , -0.21384668, -0.2857628 ],            dtype=float32)]

Parameters of first layer at end of training:
[DeviceArray([[-0.5500405 ],
             [ 0.67074966],
             [-0.9094155 ],
             [-0.5518596 ],
             [-0.1648488 ],
             [ 0.98241615],
             [-0.9118581 ],
             [ 0.32483125]], dtype=float32), DeviceArray([ 0.8876705 ,  0.4363706 , -0.878813  ,  0.26387787,
             -0.68248963, -0.9517925 , -0.21384668, -0.2857628 ],            dtype=float32)]

Parameters of last layer at initialisation:
[DeviceArray([[ 0.33031464,  0.16732198,  0.04151077, -0.01495699,
              -0.00766617,  0.08186949,  0.33581698,  0.13139524]],            dtype=float32), DeviceArray([-0.05869024], dtype=float32)]

Parameters of last layer at end of training:
[DeviceArray([[-2.8348157 ,  3.3568215 , -0.76763505, -2.1319082 ,
              -0.00766617,  1.3563706 , -1.6294043 ,  0.6064026 ]],            dtype=float32), DeviceArray([-0.10058278], dtype=float32)]