Used to tie together multiple nodes across a PyTree.
Note that Equinox modules are PyTrees -- so the same layer, appearing in two different parts of the tree, will be treated as two copies of this layer. For example,
    class SubModel(eqx.Module):
        linear: eqx.nn.Linear

    class Model(eqx.Module):
        linear: eqx.nn.Linear
        submodel: SubModel

        def __init__(self):
            linear = eqx.nn.Linear(...)
            self.linear = linear
            self.submodel = SubModel(linear)
This declares model.linear and model.submodel.linear as two separate layers. They will start with the same initial parameter values, and then update independently during training.
When we really do want to share layers or weights across different parts of a model, eqx.nn.Shared exists as a way to easily express this in the PyTree paradigm.
It is common in many language models to have an initial embedding matrix at the start, and then to reuse this as the weight of the final linear transformation.
    import equinox as eqx
    import jax
    import jax.numpy as jnp
    from jaxtyping import Array, Int

    class LanguageModel(eqx.Module):
        shared: eqx.nn.Shared

        def __init__(self):
            embedding = eqx.nn.Embedding(...)
            linear = eqx.nn.Linear(...)
            # These two weights will now be tied together: the linear weight
            # (selected by `where`) is replaced by the embedding weight
            # (returned by `get`).
            where = lambda embed_and_lin: embed_and_lin[1].weight
            get = lambda embed_and_lin: embed_and_lin[0].weight
            self.shared = eqx.nn.Shared((embedding, linear), where, get)

        def __call__(self, tokens: Int[Array, "sequence"]):
            # Expand back out so we can evaluate these layers.
            embedding, linear = self.shared()
            assert embedding.weight is linear.weight  # same parameter!
            # Now go ahead and evaluate your language model.
            values = jax.vmap(embedding)(tokens)
            ...  # other layers, probably
            return jax.vmap(linear)(values)
(Side note: you will sometimes see some authors referring to transposing
the embedding matrix prior to the final linear layer. This is because some
other libraries store the weight matrices of linear layers the other way
around. If that had been necessary here then we could have done it with
get = lambda embed_and_lin: jnp.transpose(embed_and_lin[0].weight).)
__init__(self, pytree: PyTree, where: Callable, get: Callable)
pytree: The PyTree to share some nodes across.
where: a function specifying either a single node, or a sequence of nodes, as with eqx.tree_at(where, pytree, ...).
get: a function, which when evaluated on pytree, returns either a single value (if where does), or a sequence of values (if where does; in this case it must be a sequence of the same length as the one returned by where).
The node(s) of get(pytree) and the corresponding value(s) of where(pytree) will be tied together.
To explain how this works, the implementation is just:
    from typing import Callable

    import equinox as eqx
    from jaxtyping import PyTree

    class Shared(eqx.Module):
        pytree: PyTree
        where: Callable
        get: Callable

        def __init__(self, pytree, where, get):
            # `0` is just some dummy value
            self.pytree = eqx.tree_at(where, pytree, replace_fn=lambda _: 0)
            self.where = where
            self.get = get

        def __call__(self):
            return eqx.tree_at(self.where, self.pytree, self.get(self.pytree))
At __init__ time, the duplicate nodes specified in where are removed from the PyTree. We no longer have a separate copy updating during training.
And then at
__call__ time, references to the values returned by
get(pytree) are put in their place. We end up with a pytree of the same
structure as what we started with, which we can now use (evaluate as a
layer etc.) as normal.
If you need to apply any transform (e.g. transposing a matrix), then this
can be done as part of
get. For example,
get = lambda pair: jnp.transpose(pair[0].weight).
__call__(self) returns a PyTree of the same structure as the original pytree, with get(pytree) in the place of the nodes at where(pytree).