Serialisation
See also the serialisation example for a worked example combining both parameters and hyperparameters together.
equinox.tree_serialise_leaves(path_or_file: Union[str, pathlib.Path, BinaryIO], pytree: PyTree, filter_spec = <function default_serialise_filter_spec>, is_leaf: Optional[Callable[[Any], bool]] = None) -> None
Save the leaves of a PyTree to file.
Arguments:
path_or_file
: The file location to save values to, or a binary file-like object.
pytree
: The PyTree whose leaves will be saved.
filter_spec
: Specifies how to save each kind of leaf. By default all JAX arrays, NumPy arrays, Python bool/int/float/complexes are saved, and all other leaf types are ignored. (See equinox.default_serialise_filter_spec.)
is_leaf
: Called on every node of pytree; if True then this node will be treated as a leaf.
Returns:
Nothing.
Example
This can be used to save a model to file.
import equinox as eqx
import jax.random as jr
model = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model)
Info
filter_spec should typically be a function (File, Any) -> None, which takes a file handle and a leaf to save, and either saves the leaf to the file or does nothing.
It can also be a PyTree of such functions, in which case the PyTree structure should be a prefix of pytree, and each function will be mapped over the corresponding sub-PyTree of pytree.
equinox.tree_deserialise_leaves(path_or_file: Union[str, pathlib.Path, BinaryIO], like: PyTree, filter_spec = <function default_deserialise_filter_spec>, is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree
Load the leaves of a PyTree from a file.
Arguments:
path_or_file
: The file location to load values from, or a binary file-like object.
like
: A PyTree of same structure, and with leaves of the same type, as the PyTree being loaded. Those leaves which are loaded will replace the corresponding leaves of like.
filter_spec
: Specifies how to load each kind of leaf. By default all JAX arrays, NumPy arrays, Python bool/int/float/complexes are loaded, and all other leaf types are not loaded, and will retain their value from like. (See equinox.default_deserialise_filter_spec.)
is_leaf
: Called on every node of like; if True then this node will be treated as a leaf.
Returns:
The loaded PyTree, formed by iterating over like and replacing some of its leaves with the leaves saved in path_or_file.
Example
This can be used to load a model from file.
import equinox as eqx
import jax.random as jr
model_original = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model_original)
model_loaded = eqx.tree_deserialise_leaves("some_filename.eqx", model_original)
# To partially load weights, do model surgery. In this case load everything
# except the final layer, which is taken from model_original.
model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], model_loaded, model_original.layers[-1])
Example
A common pattern is the following:
def run(..., load_path=None):
    if load_path is None:
        model = Model(...hyperparameters...)
    else:
        model = eqx.filter_eval_shape(Model, ...hyperparameters...)
        model = eqx.tree_deserialise_leaves(load_path, model)
Here a model is either created directly (when load_path is None), or a suitable like is constructed (e.g. when resuming training), where equinox.filter_eval_shape is used to avoid creating spurious short-lived arrays taking up memory.
Info
filter_spec should typically be a function (File, Any) -> Any, which takes a file handle and a leaf from like, and either returns the corresponding loaded leaf, or returns the leaf from like unchanged.
It can also be a PyTree of such functions, in which case the PyTree structure should be a prefix of like, and each function will be mapped over the corresponding sub-PyTree of like.
equinox.default_serialise_filter_spec(f: BinaryIO, x: Any) -> None
Default filter specification for serialising a leaf.
Arguments
f
: A file-like object.
x
: The leaf to be saved on the disk.
Returns
Nothing.
Info
This function can be extended to customise the serialisation behaviour for leaves.
Example
Skipping saving of jax.Array.
import jax
import jax.numpy as jnp
import equinox as eqx
tree = (jnp.array([1, 2, 3]), [4, 5, 6])
# Write nothing for JAX arrays; defer to the default for every other leaf.
new_filter_spec = lambda f, x: (
    None if isinstance(x, jax.Array) else eqx.default_serialise_filter_spec(f, x)
)
eqx.tree_serialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)
equinox.default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any
Default filter specification for deserialising saved data.
Arguments
f
: A file-like object.
x
: The leaf for which the data needs to be loaded.
Returns
The new value for datatype x.
Info
This function can be extended to customise the deserialisation behaviour for leaves.
Example
Skipping loading of jax.Array.
import jax
import jax.numpy as jnp
import equinox as eqx
tree = (jnp.array([4, 5, 6]), [1, 2, 3])
# Keep JAX arrays from `tree` as they are; load every other leaf from the file.
new_filter_spec = lambda f, x: (
    x if isinstance(x, jax.Array) else eqx.default_deserialise_filter_spec(f, x)
)
new_tree = eqx.tree_deserialise_leaves("some_filename.eqx", tree, filter_spec=new_filter_spec)