
Serialisation

See also the serialisation example for a worked example that saves and loads both parameters and hyperparameters.


equinox.tree_serialise_leaves(path_or_file: Union[str, pathlib.Path, BinaryIO], pytree: PyTree, filter_spec = equinox.default_serialise_filter_spec, is_leaf: Optional[Callable[[Any], bool]] = None) -> None

Save the leaves of a PyTree to file.

Arguments:

  • path_or_file: The file location to save values to or a binary file-like object.
  • pytree: The PyTree whose leaves will be saved.
  • filter_spec: Specifies how to save each kind of leaf. By default, all JAX arrays, NumPy arrays, and Python bool/int/float/complex values are saved, and all other leaf types are ignored. (See equinox.default_serialise_filter_spec.)
  • is_leaf: Called on every node of pytree; if True then this node will be treated as a leaf.

Returns:

Nothing.

Example

This can be used to save a model to file.

import equinox as eqx
import jax.random as jr

model = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model)

Info

filter_spec should typically be a function (File, Any) -> None, which takes a file handle and a leaf to save, and either saves the leaf to the file or does nothing.

It can also be a PyTree of such functions, in which case the PyTree structure should be a prefix of pytree, and each function will be mapped over the corresponding sub-PyTree of pytree.


equinox.tree_deserialise_leaves(path_or_file: Union[str, pathlib.Path, BinaryIO], like: PyTree, filter_spec = equinox.default_deserialise_filter_spec, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree

Load the leaves of a PyTree from a file.

Arguments:

  • path_or_file: The file location to load values from or a binary file-like object.
  • like: A PyTree with the same structure, and with leaves of the same types, as the PyTree being loaded. Leaves that are loaded replace the corresponding leaves of like.
  • filter_spec: Specifies how to load each kind of leaf. By default, all JAX arrays, NumPy arrays, and Python bool/int/float/complex values are loaded; all other leaf types are not loaded and retain their value from like. (See equinox.default_deserialise_filter_spec.)
  • is_leaf: Called on every node of like; if True then this node will be treated as a leaf.

Returns:

The loaded PyTree, formed by iterating over like and replacing some of its leaves with the leaves saved in path_or_file.

Example

This can be used to load a model from file.

import equinox as eqx
import jax.random as jr

model_original = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model_original)
model_loaded = eqx.tree_deserialise_leaves("some_filename.eqx", model_original)

# To partially load weights, do model surgery. In this case load everything
# except the final layer.
model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], model_loaded, model_original.layers[-1])

Example

A common pattern is the following:

def run(..., load_path=None):
    if load_path is None:
        model = Model(...hyperparameters...)
    else:
        model = eqx.filter_eval_shape(Model, ...hyperparameters...)
        model = eqx.tree_deserialise_leaves(load_path, model)
in which either a model is created directly (e.g. at the start of training), or a suitable like is constructed and then filled in from disk (e.g. when resuming training). In the latter case, equinox.filter_eval_shape builds like from shapes and dtypes alone, which avoids allocating spurious short-lived arrays that would only be overwritten by the loaded values.

Info

filter_spec should typically be a function (File, Any) -> Any, which takes a file handle and a leaf from like, and either returns the corresponding loaded leaf, or returns the leaf from like unchanged.

It can also be a PyTree of such functions, in which case the PyTree structure should be a prefix of like, and each function will be mapped over the corresponding sub-PyTree of like.


equinox.default_serialise_filter_spec(f: BinaryIO, x: Any) -> None

Default filter specification for serialising a leaf.

Arguments:

  • f: A binary file-like object to write to.
  • x: The leaf to be saved.

Returns:

Nothing.

Info

Custom serialisation behaviour can be obtained by writing a filter_spec that handles some leaves itself and delegates all others to this function.

Example

Skipping saving of jax.Array leaves.

import jax
import jax.numpy as jnp
import equinox as eqx

tree = (jnp.array([1, 2, 3]), [4, 5, 6])
# Skip JAX arrays; defer every other leaf to the default behaviour.
new_filter_spec = lambda f, x: (
    None if isinstance(x, jax.Array) else eqx.default_serialise_filter_spec(f, x)
)
eqx.tree_serialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)

equinox.default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any

Default filter specification for deserialising saved data.

Arguments:

  • f: A binary file-like object to read from.
  • x: The leaf from like for which data should be loaded.

Returns:

The loaded value for the leaf x, or x itself if leaves of its type are not loaded.

Info

Custom deserialisation behaviour can be obtained by writing a filter_spec that handles some leaves itself and delegates all others to this function.

Example

Skipping loading of jax.Array leaves.

import jax
import jax.numpy as jnp
import equinox as eqx

# Reads the file written by the serialisation example above.
tree = (jnp.array([4, 5, 6]), [1, 2, 3])
# Keep JAX arrays from `tree`; load every other leaf as usual.
new_filter_spec = lambda f, x: (
    x if isinstance(x, jax.Array) else eqx.default_deserialise_filter_spec(f, x)
)
new_tree = eqx.tree_deserialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)