# Sharing layers

## equinox.nn.Shared (Module)
Used to tie together multiple nodes across a PyTree.
Note that Equinox modules are PyTrees -- so the same layer, appearing in two different parts of the tree, will be treated as two copies of that layer. For example,
```python
class SubModel(eqx.Module):
    linear: eqx.nn.Linear

class Model(eqx.Module):
    linear: eqx.nn.Linear
    submodel: SubModel

    def __init__(self):
        linear = eqx.nn.Linear(...)
        self.linear = linear
        self.submodel = SubModel(linear)
```
Here, `model.linear` and `model.submodel.linear` will be treated as two separate layers. They will start with the same initial parameter values, and then update independently during training.
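As a concrete illustration, here is a minimal sketch of the above, with made-up layer sizes and a made-up key filled in for the `...`. The two references end up as separate array leaves of the model's PyTree:

```python
import equinox as eqx
import jax
import jax.random as jr

class SubModel(eqx.Module):
    linear: eqx.nn.Linear

class Model(eqx.Module):
    linear: eqx.nn.Linear
    submodel: SubModel

    def __init__(self, key):
        # Hypothetical size (4, 4) chosen purely for demonstration.
        linear = eqx.nn.Linear(4, 4, key=key)
        self.linear = linear
        self.submodel = SubModel(linear)

model = Model(jr.PRNGKey(0))
# Four array leaves: `model.linear` and `model.submodel.linear` each
# contribute their own weight and bias, even though they start identical.
arrays = [x for x in jax.tree_util.tree_leaves(model) if eqx.is_array(x)]
assert len(arrays) == 4
```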
When we really do want to share layers or weights across different parts of a model, `eqx.nn.Shared` exists as a way to easily express this in the PyTree paradigm.
**Example**
It is common in many language models to have an initial embedding matrix at the start, and then to reuse this as the weight of the final linear transformation.
```python
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Int

class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self):
        embedding = eqx.nn.Embedding(...)
        linear = eqx.nn.Linear(...)
        # These two weights will now be tied together.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, tokens: Int[Array, "sequence"]):
        # Expand back out so we can evaluate these layers.
        embedding, linear = self.shared()
        assert embedding.weight is linear.weight  # same parameter!
        # Now go ahead and evaluate your language model.
        values = jax.vmap(embedding)(tokens)
        ...  # other layers, probably
        return jax.vmap(linear)(values)
```
(Side note: you will sometimes see authors refer to transposing the embedding matrix prior to the final linear layer. This is because some other libraries store the weight matrices of linear layers the other way around. If that had been necessary here, then we could have done it with `get = lambda embed_and_lin: jnp.transpose(embed_and_lin[0].weight)`.)
### `__init__(self, pytree: PyTree, where: Callable, get: Callable)`
**Arguments:**

- `pytree`: The PyTree to share some nodes across.
- `where`: A function specifying either a single node, or a sequence of nodes, as with `eqx.tree_at(where, pytree, ...)`.
- `get`: A function which, when evaluated on `pytree`, returns either a single value (if `where` does), or a sequence of values (if `where` does, in which case it must be a sequence of the same length as `where`).

The node(s) specified by `where(pytree)` and the value(s) returned by `get(pytree)` will be tied together.
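As a hedged sketch of the sequence-of-nodes case (the layer sizes and variable names below are made up for illustration), `where` can select several nodes at once, with `get` returning a matching sequence:

```python
import equinox as eqx
import jax.random as jr

k1, k2 = jr.split(jr.PRNGKey(0))
lin1 = eqx.nn.Linear(3, 3, key=k1)
lin2 = eqx.nn.Linear(3, 3, key=k2)

# Tie both the weight and the bias of `lin2` to those of `lin1`.
where = lambda pair: (pair[1].weight, pair[1].bias)
get = lambda pair: (pair[0].weight, pair[0].bias)
shared = eqx.nn.Shared((lin1, lin2), where, get)

lin1_out, lin2_out = shared()
assert lin1_out.weight is lin2_out.weight  # tied
assert lin1_out.bias is lin2_out.bias      # tied
```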
**Info**

To explain how this works, the implementation is just:
```python
class Shared(eqx.Module):
    pytree: PyTree
    where: Callable
    get: Callable

    def __init__(self, pytree, where, get):
        # `0` is just some dummy value.
        self.pytree = eqx.tree_at(where, pytree, replace_fn=lambda _: 0)
        self.where = where
        self.get = get

    def __call__(self):
        return eqx.tree_at(self.where, self.pytree, self.get(self.pytree))
```
At `__init__` time, the duplicate nodes specified by `where` are removed from the PyTree, so we no longer have a separate copy updating during training.

And then at `__call__` time, references to the values returned by `get(pytree)` are put back in their place. We end up with a PyTree of the same structure as what we started with, which we can now use (evaluate as a layer, etc.) as normal.
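Here is a small sketch of that first step (again with made-up sizes): after construction, the tied node no longer appears as a separate array leaf of the `Shared` module.

```python
import equinox as eqx
import jax
import jax.random as jr

k1, k2 = jr.split(jr.PRNGKey(0))
lin1 = eqx.nn.Linear(3, 3, key=k1)
lin2 = eqx.nn.Linear(3, 3, key=k2)

where = lambda pair: pair[1].weight
get = lambda pair: pair[0].weight
shared = eqx.nn.Shared((lin1, lin2), where, get)

# The tuple `(lin1, lin2)` alone has 4 array leaves (two weights, two
# biases). Inside `Shared`, the tied weight has been replaced by a
# non-array placeholder, leaving 3.
n_arrays = len([x for x in jax.tree_util.tree_leaves(shared) if eqx.is_array(x)])
assert n_arrays == 3
```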
**Tip**

If you need to apply any transform (e.g. transposing a matrix), then this can be done as part of `get`. For example, `get = lambda pair: jnp.transpose(pair[1].weight)`.
### `__call__(self)`
**Arguments:**

None.

**Returns:**

A PyTree of the same structure as the original `pytree`, with `get(pytree)` in place of the nodes at `where(pytree)`.
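As a brief self-contained sketch (layer sizes made up, and reusing the toy construction from above), the returned PyTree has the same structure as the original, with the `get` value substituted at the `where` node:

```python
import equinox as eqx
import jax
import jax.random as jr

k1, k2 = jr.split(jr.PRNGKey(0))
lin1 = eqx.nn.Linear(3, 3, key=k1)
lin2 = eqx.nn.Linear(3, 3, key=k2)
where = lambda pair: pair[1].weight
get = lambda pair: pair[0].weight
shared = eqx.nn.Shared((lin1, lin2), where, get)

out = shared()
# Same tuple-of-Linears structure as the `(lin1, lin2)` we started with...
assert jax.tree_util.tree_structure(out) == jax.tree_util.tree_structure((lin1, lin2))
# ...and the node selected by `where` is now the value selected by `get`.
assert out[1].weight is out[0].weight
```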