
Serialising both weights and hyperparameters

Equinox has facilities for the serialisation of the leaves of arbitrary PyTrees. The most basic use is to call eqx.tree_serialise_leaves(filename, model) to write all weights to a file. Deserialisation requires a PyTree of the correct shape to serve as a "skeleton" of sorts, whose weights are then read from the file with model = eqx.tree_deserialise_leaves(filename, skeleton).
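
For instance, a minimal round trip with this basic API might look as follows (a sketch: the eqx.nn.Linear model, the key values, and the filename "linear.eqx" are just placeholders):

import equinox as eqx
import jax.random as jr

# Create a small model and write its leaves (the arrays) to disk.
linear = eqx.nn.Linear(2, 3, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("linear.eqx", linear)

# To load it back, build a skeleton with the same structure -- its freshly
# initialised weights are immediately overwritten by the values read from the file.
skeleton = eqx.nn.Linear(2, 3, key=jr.PRNGKey(1))
loaded = eqx.tree_deserialise_leaves("linear.eqx", skeleton)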

However, a typical model has both weights (arrays stored as leaves in the PyTree) and hyperparameters (the size of the network, etc.). When deserialising, we would like to read the hyperparameters as well as the weights. Ideally they should be stored in the same file. We can accomplish this as follows.

Let's import everything and set up a simple model:

import json

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr


def make(*, key, size, width, depth, use_tanh=False):
    if use_tanh:
        activation = jnp.tanh
    else:
        activation = jax.nn.relu
    # (This is not meant to be a realistically useful model.)
    return eqx.nn.MLP(
        in_size=size,
        out_size=1,
        width_size=width,
        depth=depth,
        activation=activation,
        key=key,
    )


hyperparameters = {"size": 5, "width": 10, "depth": 3, "use_tanh": True}
model = make(key=jr.PRNGKey(0), **hyperparameters)

At this point, we haven't just created a model, but also defined a function that allows us to re-create a model of the same structure. Additionally, the hyperparameters used to create the model have been saved for later serialisation.

We may now train the model as usual. When the time comes to serialise, we want to put both hyperparameters and leaves in the same file. This is accomplished like so:

def save(filename, hyperparams, model):
    with open(filename, "wb") as f:
        hyperparam_str = json.dumps(hyperparams)
        f.write((hyperparam_str + "\n").encode())
        eqx.tree_serialise_leaves(f, model)


save("model.eqx", hyperparameters, model)

We've been a bit slick here. A single file now contains a valid JSON expression storing the hyperparameters and, after a newline, the bytes serialising the weights of our model. Implicitly we're relying on the fact that Python's built-in JSON serialisation (json.dumps) places everything on a single line.
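
For instance, reading back just the first line of the file we wrote above should show the JSON header (the exact bytes depend on the hyperparameters used):

with open("model.eqx", "rb") as f:
    print(f.readline())
# b'{"size": 5, "width": 10, "depth": 3, "use_tanh": true}\n'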

With the hyperparameters and model serialised in this way, deserialisation occurs in three steps:

  1. Read the first line from the file, and parse the JSON into a dictionary.
  2. Construct a skeleton model using make(...).
  3. Have Equinox deserialise the remainder of the file, using the skeleton.

def load(filename):
    with open(filename, "rb") as f:
        hyperparams = json.loads(f.readline().decode())
        model = make(key=jr.PRNGKey(0), **hyperparams)
        return eqx.tree_deserialise_leaves(f, model)


newmodel = load("model.eqx")

# Check that it's loaded correctly:
assert model.layers[1].weight[2, 2] == newmodel.layers[1].weight[2, 2]
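
The assertion above only spot-checks a single weight. To verify that every leaf round-tripped, one option is eqx.tree_equal, which compares entire PyTrees:

# Compare every leaf (weights, biases, ...) of the two models at once.
assert eqx.tree_equal(model, newmodel)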

Your hyperparameters must be serialisable

Note that the hyperparameters themselves must be serialisable (in the above example, as JSON). Fortunately, this is pretty typical.
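
As a quick illustration: plain Python types (ints, floats, strings, booleans, lists, dicts) serialise fine, whereas something like a function or a JAX array does not:

# Plain Python types are fine:
json.dumps({"size": 5, "width": 10, "depth": 3, "use_tanh": True})

# Whereas a function (or a JAX array) is not JSON-serialisable:
try:
    json.dumps({"activation": jnp.tanh})
except TypeError:
    pass  # raises TypeError: not JSON serialisable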

Why not pickle?

The pickle module is Python's go-to for all-purpose serialisation. Why didn't we just use that?

First, Equinox and JAX make rich use of unpickleable quantities, e.g. lambda expressions. This can potentially be cured by using another package (e.g. dill or cloudpickle).

Second, pickling is notoriously fraught with edge cases. If we can take a more structured approach (as above), then it's nicer to do so. (For example, the pickle format changes from time to time.)

Third, unpickling loads and runs arbitrary code. When you download an open-source model, this can be a serious security concern! We all expect a file named *.py to be potentially malicious, but a file ostensibly containing "just floating-point numbers" should be safe to use from an untrusted source. The methods described above allow weights to be shared safely as long as the underlying model code is trusted; the serialised file really is interpreted as just an array of numbers.

Other notes

  • Many variations are possible. For example, Equinox serialisation doesn't have to write to a file: you can write to any compatible buffer, e.g. an io.BytesIO object (see the sketch after this list).
  • If you serialise/deserialise between training and inference, and you are using equinox.nn.BatchNorm, equinox.nn.Dropout etc., then make sure to set your desired inference flag when loading. Perhaps make this a required argument to make(...).
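
For the io.BytesIO point, a minimal in-memory round trip (a sketch, reusing the model defined earlier) might look like:

import io

buffer = io.BytesIO()
eqx.tree_serialise_leaves(buffer, model)  # write the weights into the in-memory buffer
buffer.seek(0)  # rewind before reading
restored = eqx.tree_deserialise_leaves(buffer, model)

And for the inference flag, one option (a sketch, assuming a recent Equinox version providing eqx.nn.inference_mode, and a model that actually contains BatchNorm/Dropout layers) is to flip all inference flags immediately after loading:

newmodel = load("model.eqx")
newmodel = eqx.nn.inference_mode(newmodel, value=True)  # sets inference=True throughout the PyTree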