quax.examples.lora

As a (actually quite useful) tech-demo, Quax provides an implementation of LoRA: Low-Rank Adaptation, which is a popular fine-tuning method for large neural network models.

Most of the time you will just need the quax.examples.lora.loraify function, which transforms an existing Equinox model.

For users who want to LoRA'ify only part of their model, the underlying quax.examples.lora.LoraArray array-ish object (which subclasses quax.ArrayValue) is also available.
quax.examples.lora.loraify(model: PyTree, *, rank: int, scale: float = 0.01, allow_materialise: bool = False, stop_gradient: bool = True, key: PRNGKeyArray) -> PyTree
Converts an Equinox model into a low-rank adapted version.
Arguments:

- model: the model to convert. This is treated as a PyTree, and all eqx.nn.Linear layers found will have their weight matrices replaced with LoraArrays.
- rank: the rank of the low-rank adaptation.
- scale: how large to initialise the a matrix of the low-rank adaptation.
- allow_materialise: if Quax encounters an operation for which no override has been specified for LoraArrays, should it (a) throw an error (allow_materialise=False, the default), or (b) silently convert the LoraArray back into a JAX array, by explicitly calculating w + a @ b (allow_materialise=True)?
- stop_gradient: whether to automatically stop the gradient (prevent training) of the original weight matrices of the linear layers.
- key: used to provide randomness for initialising the low-rank adaptation.
Returns:

A copy of model, with all linear layers having their weight matrices replaced with LoraArrays.

Typically, the result should then be used with a call to quax.quaxify, which will trace your JAX program and replace all interactions with LoRA arrays using the appropriate multiple dispatch rules.
Example
import equinox as eqx
import quax
import quax.examples.lora as lora
import jax.random as jr
key = jr.PRNGKey(0)
mlp = eqx.nn.MLP(...)
mlp = lora.loraify(mlp, rank=2, key=key)
# Wrap in `quaxify` and call as normal.
some_output = quax.quaxify(mlp)(some_input)
quax.examples.lora.LoraArray (ArrayValue)
Replaces a matrix w in R^{n x m} with w + a @ b, where a in R^{n x k} and b in R^{k x m}.

Typically k is much smaller than n or m, and so w + a @ b is described as a "low-rank adaptation" of w. The value of k is the "rank" of the adaptation.

Note that this does not materialise the sum w + a @ b into a single matrix, but instead stores it as the three separate matrices w, a and b. This is because the typical use-case for LoRA is to update just the a and b matrices when fine-tuning a neural network.

This implementation makes use of Quax's multiple-dispatch capabilities to calculate matrix-vector products (w + a @ b) @ x via w @ x + a @ (b @ x), which turns out to be computationally cheaper.
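To see why the factored form is cheaper, here is a small plain-JAX sketch. It deliberately bypasses Quax and uses ordinary jax.numpy arrays; the sizes are made up for illustration, and b is given random entries (rather than the zeros a LoraArray starts with) purely so that the low-rank term is visible in the output.

import jax.numpy as jnp
import jax.random as jr

n, m, k = 512, 256, 4
kw, ka, kb, kx = jr.split(jr.PRNGKey(0), 4)
w = jr.normal(kw, (n, m))
a = jr.normal(ka, (n, k))
b = jr.normal(kb, (k, m))
x = jr.normal(kx, (m,))

# Materialised form: building w + a @ b costs an O(n k m) matmul and O(n m)
# extra memory before the matrix-vector product even starts.
dense = (w + a @ b) @ x

# Factored form, as in the dispatch rule above: two skinny matmuls costing
# only O(k m) + O(n k) on top of the unavoidable w @ x.
factored = w @ x + a @ (b @ x)

assert jnp.allclose(dense, factored, rtol=1e-3, atol=1e-3)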
__init__(self, weight: Array, *, rank: int, scale: float = 0.01, allow_materialise: bool = False, stop_gradient: bool = True, key: PRNGKeyArray)
Arguments:

- weight: the original weight to wrap.
- rank: the rank of the low-rank adaptation.
- scale: a will be initialised at Normal(0, scale^2). (b is initialised at zero, so the adaptation initially leaves the wrapped weight's behaviour unchanged.)
- allow_materialise: if Quax encounters an operation for which no override has been specified for LoraArrays, should it (a) throw an error (allow_materialise=False, the default), or (b) silently convert the LoraArray back into a JAX array, by explicitly calculating w + a @ b (allow_materialise=True)?
- stop_gradient: whether to automatically stop the gradient (prevent training) of the original weight matrix weight.
- key: used to provide randomness for initialising a.
Example
Here's a copy of the LoRA example from the README again:
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora
#
# Start off with any JAX program: here, the forward pass through a linear layer.
#
key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))
def run(model, x):
    return model(x)
run(linear, vector) # can call this as normal
#
# Now let's Lora-ify it.
#
# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)
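In this case the natural next step is to fine-tune the adapted model. The snippet below is a rough sketch of what that might look like, continuing from the example above; it is not part of the Quax API, just the usual Equinox eqx.filter_grad pattern with a placeholder loss. Because loraify defaults to stop_gradient=True, gradients only flow into the a and b matrices, so an optimiser step would update just the adaptation.

import jax.numpy as jnp

@eqx.filter_grad
def loss_fn(model, x):
    # Placeholder loss: squared norm of the layer's output. A real training
    # loss would go here instead.
    out = quax.quaxify(run)(model, x)
    return jnp.sum(out**2)

grads = loss_fn(lora_linear, vector)
# `grads` mirrors the structure of `lora_linear`. With stop_gradient=True (the
# default), the original weight's gradient is zero; only `a` and `b` receive
# gradient signal.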