
quax.examples.lora

As a (actually quite useful) tech-demo, Quax provides an implementation of LoRA: Low-Rank Adaptation, which is a popular fine-tuning method for large neural network models.

Most of the time you will just need the quax.examples.lora.loraify function, which transforms an existing Equinox model.

For a user who wants to LoRA-ify only part of their model, the underlying quax.examples.lora.LoraArray array-ish object (which subclasses quax.ArrayValue) is also available.


quax.examples.lora.loraify(model: PyTree, *, rank: int, scale: float = 0.01, allow_materialise: bool = False, stop_gradient: bool = True, key: PRNGKeyArray) -> PyTree

Converts an Equinox model into a low-rank adapted version.

Arguments:

  • model: the model to convert. This is treated as a PyTree, and all eqx.nn.Linear layers found will have their weight matrices replaced with LoraArrays.
  • rank: the rank of the low-rank adaptation.
  • scale: how large to initialise the a matrix of the low-rank adaptation: a is drawn from Normal(0, scale^2), while b is initialised at zero.
  • allow_materialise: if Quax encounters an operation for which there has not been a specific override specified for LoraArrays, should it either (a) throw an error (allow_materialise=False, the default), or (b) silently convert the LoraArray back into a JAX array, by explicitly calculating w + a @ b (allow_materialise=True).
  • stop_gradient: whether to automatically stop the gradient (prevent training) of the original weight matrices of the linear layers.
  • key: used to provide randomness for initialising the low-rank adaptation.
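To make stop_gradient concrete, here is a minimal sketch in plain JAX that does not use Quax. It assumes that stop_gradient=True has roughly the effect of wrapping the original weight in jax.lax.stop_gradient; the shapes, scale and squared-norm loss are illustrative choices only.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Illustrative sizes: w is n x m, the adaptation has rank k.
n, m, k, scale = 4, 3, 2, 0.01
kw, ka, kx = jr.split(jr.PRNGKey(0), 3)
w = jr.normal(kw, (n, m))
a = scale * jr.normal(ka, (n, k))  # a ~ Normal(0, scale^2)
b = jnp.zeros((k, m))              # b starts at zero
x = jr.normal(kx, (m,))


def loss(params, x):
    w, a, b = params
    # Assumed analogue of stop_gradient=True: the base weight is frozen.
    w = jax.lax.stop_gradient(w)
    y = w @ x + a @ (b @ x)
    return jnp.sum(y ** 2)


# Gradients with respect to (w, a, b), returned as a tuple of the same shape.
grad_w, grad_a, grad_b = jax.grad(loss)((w, a, b), x)
```

The gradient for w is identically zero, so an optimiser leaves the original weights untouched. Because b starts at zero, the gradient for a is also zero on the very first step; only b receives a non-zero gradient until it moves away from zero.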

Returns:

A copy of model, with all linear layers having their weight matrices replaced with LoraArrays.

Typically, the result should then be used with a call to quax.quaxify, which traces your JAX program and replaces every interaction with a LoraArray using the appropriate multiple-dispatch rule.

Example

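import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

key = jr.PRNGKey(0)
mlp = eqx.nn.MLP(...)
# Replace the weight of every eqx.nn.Linear inside the MLP with a LoraArray.
mlp = lora.loraify(mlp, rank=2, key=key)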
# Wrap in `quaxify` and call as normal.
some_output = quax.quaxify(mlp)(some_input)

quax.examples.lora.LoraArray (ArrayValue)

Replaces a matrix w in R^{n x m} with w + a @ b, where a in R^{n x k} and b in R^{k x m}.

Typically k is much smaller than n or m, and so w + a @ b is described as a "low rank adaptation" of w. The value of k is the "rank" of the adaptation.
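As a quick numerical check of the "low rank" description, the sketch below (plain JAX, with arbitrary illustrative sizes) builds the update a @ b for k = 2, confirms its rank is at most k, and compares the number of parameters held in a and b against a dense n x m update.

```python
import jax.numpy as jnp
import jax.random as jr

# Illustrative sizes only.
n, m, k = 64, 32, 2
ka, kb = jr.split(jr.PRNGKey(0))
a = jr.normal(ka, (n, k))
b = jr.normal(kb, (k, m))

delta = a @ b                              # the n x m update described by a and b
rank = int(jnp.linalg.matrix_rank(delta))  # at most k

full_params = n * m          # entries in a dense update of w
lora_params = n * k + k * m  # entries in a and b together
```

Here the low-rank factors hold 192 numbers, against 2048 for a dense update of the same matrix.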

Note that this does not materialise the sum w + a @ b into a single matrix, but instead stores it as three separate w, a, b matrices. This is because the typical use-case for LoRA is to update just the a and b matrices when fine-tuning a neural network.
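The storage scheme, together with the allow_materialise switch described for loraify, can be illustrated with a toy container. ToyLoraMatrix and its fallback method are hypothetical names used only for this sketch; it is not how Quax's LoraArray is implemented, which instead relies on Quax's multiple-dispatch rules.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ToyLoraMatrix:
    """Hypothetical illustration only; not Quax's LoraArray."""

    w: jax.Array  # original n x m weight
    a: jax.Array  # n x k factor
    b: jax.Array  # k x m factor
    allow_materialise: bool = False

    def materialise(self) -> jax.Array:
        # Explicitly form the dense n x m matrix w + a @ b.
        return self.w + self.a @ self.b

    def fallback(self, op, *args):
        # Stand-in for "an operation with no LoRA-specific rule".
        if not self.allow_materialise:
            name = getattr(op, "__name__", repr(op))
            raise NotImplementedError(f"no LoRA-specific rule for {name}")
        # Otherwise silently densify and apply the ordinary operation.
        return op(self.materialise(), *args)
```

The three factors are kept separate so that a and b remain individually trainable; the dense sum only appears if an operation forces it.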

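This implementation makes use of Quax's multiple-dispatch capabilities to calculate matrix-vector products (w + a @ b) @ x via w @ x + a @ (b @ x). This is cheaper because the n x m matrix a @ b is never formed: the right-hand side costs roughly nm + k(n + m) multiply-adds, against roughly nmk + nm for the left-hand side.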
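A minimal plain-JAX sketch (no Quax involved, arbitrary sizes) that checks the two evaluation orders agree numerically and counts the approximate multiply-adds for each, ignoring the cheap elementwise additions.

```python
import jax.numpy as jnp
import jax.random as jr

# Illustrative sizes only.
n, m, k = 64, 32, 4
kw, ka, kb, kx = jr.split(jr.PRNGKey(0), 4)
w = jr.normal(kw, (n, m))
a = jr.normal(ka, (n, k))
b = jr.normal(kb, (k, m))
x = jr.normal(kx, (m,))

# Dense order: build the n x m matrix a @ b first, then multiply by x.
naive = (w + a @ b) @ x
# LoRA order: only matrix-vector products, a @ b is never formed.
cheap = w @ x + a @ (b @ x)

# Approximate multiply-add counts for each order.
naive_cost = n * m * k + n * m        # a @ b, then the n x m matvec
cheap_cost = n * m + k * m + n * k    # w @ x, b @ x, a @ (b @ x)
```

With these sizes the dense order needs 10240 multiply-adds against 2432 for the LoRA order, and the gap grows with k.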

__init__(self, weight: Array, *, rank: int, scale: float = 0.01, allow_materialise: bool = False, stop_gradient: bool = True, key: PRNGKeyArray)

Arguments:

  • weight: the original weight to wrap.
  • rank: the rank of the low-rank adaptation.
  • scale: a will be initialised at Normal(0, scale^2). (b is initialised at zero.)
  • allow_materialise: if Quax encounters an operation for which there has not been a specific override specified for LoraArrays, should it either (a) throw an error (allow_materialise=False, the default), or (b) silently convert the LoraArray back into a JAX array, by explicitly calculating w + a @ b (allow_materialise=True).
  • stop_gradient: whether to automatically stop the gradient (prevent training) of the original weight matrix weight.
  • key: used to provide randomness for initialising a.
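Because b is initialised at zero, the adapted weight w + a @ b equals w exactly when the LoraArray is created, so a freshly LoRA-ified layer initially computes the same outputs as the original. The plain-JAX sketch below mirrors that initialisation (sampling a as scale times a standard normal, which is Normal(0, scale^2)); it does not call Quax, and the shapes are chosen to match the weight of eqx.nn.Linear(10, 12), which has shape (12, 10).

```python
import jax.numpy as jnp
import jax.random as jr

# Shapes match eqx.nn.Linear(10, 12): weight is (out_features, in_features).
n, m, k, scale = 12, 10, 2, 0.01
kw, ka = jr.split(jr.PRNGKey(0))
w = jr.normal(kw, (n, m))
a = scale * jr.normal(ka, (n, k))  # a ~ Normal(0, scale^2)
b = jnp.zeros((k, m))              # b starts at zero

# With b == 0, the low-rank term vanishes and the adapted weight is w.
adapted = w + a @ b
```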

Example

Here's a copy of the LoRA example from the README again:

import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
  return model(x)

run(linear, vector)  # can call this as normal

#
# Now let's LoRA-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)