
Serialisation

equinox.tree_serialise_leaves(path: Union[str, pathlib.Path], pytree: PyTree, filter_spec=default_serialise_filter_spec, is_leaf: Callable[[Any], bool] = _is_index) -> None

Save the leaves of a PyTree to file.

Arguments:

  • path: The file location to save values to.
  • pytree: The PyTree whose leaves will be saved.
  • filter_spec: Specifies how to save each kind of leaf. By default, all JAX arrays, NumPy arrays, and Python bool/int/float/complex values are saved; equinox.experimental.StateIndex instances have their value looked up and saved; and all other leaf types are ignored. (See equinox.default_serialise_filter_spec.)
  • is_leaf: Called on every node of pytree; if it returns True, that node is treated as a leaf.
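To illustrate what is_leaf changes, here is a minimal, stdlib-only sketch of leaf traversal. The helper leaves is hypothetical: a simplified stand-in for jax.tree_util.tree_leaves that only understands tuples, lists and dicts (visiting dict keys in sorted order). Equinox's default, _is_index, presumably uses the same mechanism to treat StateIndex instances as single leaves.

```python
from typing import Any, Callable, Optional


def leaves(tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None) -> list:
    """Hypothetical, simplified stand-in for jax.tree_util.tree_leaves."""
    # If is_leaf says so, stop here and treat the whole node as a single leaf.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, (tuple, list)):
        return [leaf for child in tree for leaf in leaves(child, is_leaf)]
    if isinstance(tree, dict):
        # Visit dict entries in sorted-key order so the traversal is deterministic.
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key], is_leaf)]
    # Anything else is a leaf.
    return [tree]


tree = {"a": (1, 2), "b": [3, (4, 5)]}
print(leaves(tree))                                          # [1, 2, 3, 4, 5]
print(leaves(tree, is_leaf=lambda n: isinstance(n, tuple)))  # [(1, 2), 3, (4, 5)]
```

With is_leaf matching tuples, each tuple would reach filter_spec as one leaf instead of being split into its elements.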

Returns:

Nothing.

Example

This can be used to save a model to file.

import equinox as eqx
import jax.random as jr

model = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model)

Info

filter_spec should typically be a function (File, Any) -> None, which takes a file handle and a leaf to save, and either saves the leaf to the file or does nothing.

It can also be a PyTree of such functions, in which case the PyTree structure should be a prefix of pytree, and each function will be mapped over the corresponding sub-PyTree of pytree.
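The filter_spec contract can be sketched without JAX. Below, save_leaves is a hypothetical helper (not part of Equinox) that walks a tree of tuples, lists and dicts and calls filter_spec(f, leaf) on each leaf in turn; when the spec is itself a tuple or list of functions, it is treated as a prefix of the tree and each function is applied to the matching sub-tree. Writing arrays with NumPy's np.save is an assumption of this sketch.

```python
import io
from typing import Any, BinaryIO

import numpy as np


def _leaves(tree: Any) -> list:
    """Simplified stand-in for jax.tree_util.tree_leaves (tuples, lists, dicts only)."""
    if isinstance(tree, (tuple, list)):
        return [leaf for child in tree for leaf in _leaves(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in _leaves(tree[key])]
    return [tree]


def save_leaves(f: BinaryIO, tree: Any, spec: Any) -> None:
    """Hypothetical helper mirroring the filter_spec contract of tree_serialise_leaves."""
    if callable(spec):
        # A single function is called on every leaf of this (sub-)tree, in order.
        for leaf in _leaves(tree):
            spec(f, leaf)
    elif isinstance(spec, (tuple, list)):
        # A container of functions is a prefix of the tree: match it child by child.
        if len(spec) != len(tree):
            raise ValueError("filter_spec is not a prefix of the tree")
        for sub_spec, sub_tree in zip(spec, tree):
            save_leaves(f, sub_tree, sub_spec)
    else:
        raise TypeError(f"unsupported filter_spec: {spec!r}")


def save_arrays(f: BinaryIO, x: Any) -> None:
    # Save NumPy arrays; do nothing for any other kind of leaf.
    if isinstance(x, np.ndarray):
        np.save(f, x)


def save_nothing(f: BinaryIO, x: Any) -> None:
    # Skip every leaf of the sub-tree this function is mapped over.
    pass


tree = (np.arange(3), {"name": "layer", "w": np.ones(2)})
buf = io.BytesIO()
# Skip the first element entirely; save only the arrays inside the dict.
save_leaves(buf, tree, (save_nothing, save_arrays))
buf.seek(0)
w = np.load(buf)  # the only array written was tree[1]["w"]
```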


equinox.tree_deserialise_leaves(path: Union[str, pathlib.Path], like: PyTree, filter_spec=default_deserialise_filter_spec, is_leaf: Callable[[Any], bool] = _is_index) -> PyTree

Load the leaves of a PyTree from a file.

Arguments:

  • path: The file location to load values from.
  • like: A PyTree with the same structure, and with leaves of the same types, as the PyTree being loaded. The leaves that are loaded replace the corresponding leaves of like.
  • filter_spec: Specifies how to load each kind of leaf. By default, all JAX arrays, NumPy arrays, and Python bool/int/float/complex values are loaded; equinox.experimental.StateIndex instances have their value looked up and stored; and all other leaf types are not loaded, retaining their value from like. (See equinox.default_deserialise_filter_spec.)
  • is_leaf: Called on every node of like; if it returns True, that node is treated as a leaf.

Returns:

The loaded PyTree, formed by iterating over like and replacing some of its leaves with the leaves saved in path.

Example

This can be used to load a model from file.

import equinox as eqx
import jax.random as jr

model_original = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model_original)
model_loaded = eqx.tree_deserialise_leaves("some_filename.eqx", model_original)

# To partially load weights: here, load everything except the final layer, which keeps its original values.
model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], model_loaded, model_original.layers[-1])

Info

filter_spec should typically be a function (File, Any) -> Any, which takes a file handle and a leaf from like, and either returns the corresponding loaded leaf or returns the leaf from like unchanged.

It can also be a PyTree of such functions, in which case the PyTree structure should be a prefix of like, and each function will be mapped over the corresponding sub-PyTree of like.
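The loading side of the contract can be sketched the same way. load_leaves is a hypothetical helper (not part of Equinox): it rebuilds like, replacing each leaf with whatever filter_spec(f, leaf) returns, and accepts a tuple or list of functions as a prefix of like. Reading with NumPy's np.load, and visiting dict keys in sorted order so reads happen in the same order as the writes, are assumptions of this sketch.

```python
import io
from typing import Any, BinaryIO

import numpy as np


def _map_leaves(fn, tree: Any) -> Any:
    """Simplified stand-in for jax.tree_util.tree_map (tuples, lists, dicts only)."""
    if isinstance(tree, (tuple, list)):
        return type(tree)(_map_leaves(fn, child) for child in tree)
    if isinstance(tree, dict):
        # Sorted keys, so leaves are visited in the same order they were saved.
        return {key: _map_leaves(fn, tree[key]) for key in sorted(tree)}
    return fn(tree)


def load_leaves(f: BinaryIO, like: Any, spec: Any) -> Any:
    """Hypothetical helper mirroring the filter_spec contract of tree_deserialise_leaves."""
    if callable(spec):
        # Each leaf of `like` is replaced by whatever spec(f, leaf) returns.
        return _map_leaves(lambda leaf: spec(f, leaf), like)
    if isinstance(spec, (tuple, list)):
        # A container of functions is a prefix of `like`: match it child by child.
        if len(spec) != len(like):
            raise ValueError("filter_spec is not a prefix of like")
        return type(like)(
            load_leaves(f, sub_like, sub_spec) for sub_spec, sub_like in zip(spec, like)
        )
    raise TypeError(f"unsupported filter_spec: {spec!r}")


def load_arrays(f: BinaryIO, x: Any) -> Any:
    # Read the next array for NumPy-array leaves; return every other leaf unchanged.
    return np.load(f) if isinstance(x, np.ndarray) else x


def keep(f: BinaryIO, x: Any) -> Any:
    # Load nothing: keep the leaf from `like` as-is.
    return x


# Pretend a previous save step wrote a single array.
buf = io.BytesIO()
np.save(buf, np.array([1.0, 2.0]))
buf.seek(0)

like = ({"b": np.zeros(2), "tag": "head"}, np.zeros(3))
loaded = load_leaves(buf, like, (load_arrays, keep))
```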


equinox.default_serialise_filter_spec(f: BinaryIO, x: Any) -> None

Default filter specification for serialising a leaf.

Arguments:

  • f: A binary file-like object to write to.
  • x: The leaf to save.

Returns:

Nothing.

Info

This function can be extended to customise the serialisation behaviour for leaves.
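As a rough illustration of the default dispatch described under tree_serialise_leaves, the following NumPy-only sketch saves arrays and Python/NumPy scalars and skips everything else. serialise_spec_sketch is a hypothetical analogue, not Equinox's implementation: it omits JAX arrays and StateIndex handling, and assumes leaves are written with np.save (so scalars become 0-d arrays).

```python
import io
from typing import Any, BinaryIO

import numpy as np


def serialise_spec_sketch(f: BinaryIO, x: Any) -> None:
    """Hypothetical NumPy-only analogue of default_serialise_filter_spec."""
    if isinstance(x, (np.ndarray, np.generic, bool, int, float, complex)):
        # Arrays and scalars are written in .npy format; scalars become 0-d arrays.
        np.save(f, x)
    # Any other leaf (strings, functions, None, ...) writes nothing.


buf = io.BytesIO()
for leaf in [np.arange(3), 2.5, "not saved", True]:
    serialise_spec_sketch(buf, leaf)

buf.seek(0)
first, second, third = np.load(buf), np.load(buf), np.load(buf)
```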

Example

Skip saving JAX arrays (jnp.ndarray), while saving all other leaves as usual.

import jax.numpy as jnp
import equinox as eqx

tree = (jnp.array([1, 2, 3]), [4, 5, 6])
# Write nothing for JAX arrays; defer to the default behaviour for every other leaf.
new_filter_spec = lambda f, x: (
    None if isinstance(x, jnp.ndarray) else eqx.default_serialise_filter_spec(f, x)
)
eqx.tree_serialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)

equinox.default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any

Default filter specification for deserialising saved data.

Arguments:

  • f: A binary file-like object to read from.
  • x: The leaf (from like) for which data should be loaded.

Returns:

The loaded value to use in place of x.

Info

This function can be extended to customise the deserialisation behaviour for leaves.
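A matching NumPy-only sketch of the loading side: arrays are read back directly, scalars are read as 0-d arrays and converted back to the type of the leaf in like, and every other leaf is returned unchanged. deserialise_spec_sketch is hypothetical and omits the JAX-array and StateIndex handling of the real default.

```python
import io
from typing import Any, BinaryIO

import numpy as np


def deserialise_spec_sketch(f: BinaryIO, x: Any) -> Any:
    """Hypothetical NumPy-only analogue of default_deserialise_filter_spec."""
    if isinstance(x, np.ndarray):
        return np.load(f)
    if isinstance(x, (np.generic, bool, int, float, complex)):
        # Scalars were stored as 0-d arrays: convert back to the type of the `like` leaf.
        return type(x)(np.load(f).item())
    # Leaves that were never saved keep their value from `like`.
    return x


# Write two leaves the way a matching serialise step would have.
buf = io.BytesIO()
np.save(buf, np.arange(3))
np.save(buf, 7)
buf.seek(0)

like = [np.zeros(3), 0, "unchanged"]
loaded = [deserialise_spec_sketch(buf, leaf) for leaf in like]
```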

Example

Skip loading JAX arrays (jnp.ndarray), which keep their values from the input tree, while loading all other leaves as usual.

import jax.numpy as jnp
import equinox as eqx

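# Assumes some_filename.eqx was written by the serialisation example above, which skipped JAX arrays.
tree = (jnp.array([4, 5, 6]), [1, 2, 3])
# Keep JAX arrays unchanged from `tree`; defer to the default behaviour for every other leaf.
new_filter_spec = lambda f, x: (
    x if isinstance(x, jnp.ndarray) else eqx.default_deserialise_filter_spec(f, x)
)
new_tree = eqx.tree_deserialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)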