# quax.examples.lora¤

As a (actually quite useful) tech-demo, Quax provides an implementation of LoRA: Low-Rank Adaptation, which is a popular fine-tuning method for large neural network models.

Most of the time you will just need the `quax.examples.lora.loraify` function, which transforms an existing Equinox model.

For a user who wants to LoRA'ify only part of their model, the underlying `quax.examples.lora.LoraArray` array-ish object (which subclasses `quax.ArrayValue`) is also available; see the worked example at the end of this page.

#### `quax.examples.lora.loraify(model: PyTree, *, rank: int, scale: float = 0.01, allow_materialise: bool = False, stop_gradient: bool = True, key: PRNGKeyArray) -> PyTree`¤

Converts an Equinox model into a low-rank adapted version.

**Arguments:**

- `model`: the model to convert. This is treated as a PyTree, and all `eqx.nn.Linear` layers found will have their weight matrices replaced with `LoraArray`s.
- `rank`: the rank of the low-rank adaptation.
- `scale`: how large to initialise the `a` matrix of the low-rank adaptation.
- `allow_materialise`: if Quax encounters an operation for which there is no specific override for `LoraArray`s, should it (a) throw an error (`allow_materialise=False`, the default), or (b) silently convert the `LoraArray` back into a JAX array, by explicitly calculating `w + a @ b` (`allow_materialise=True`)?
- `stop_gradient`: whether to automatically stop the gradient (prevent training) of the original weight matrices of the linear layers.
- `key`: used to provide randomness for initialising the low-rank adaptation.

**Returns:**

A copy of `model`, with all linear layers having their weight matrices replaced with `LoraArray`s.

Typically, the result should then be used with a call to `quax.quaxify`, which will trace your JAX program and replace all interactions with LoRA arrays using the appropriate multiple dispatch rules.

**Example**

```
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

key = jr.PRNGKey(0)
mlp = eqx.nn.MLP(...)
mlp = lora.loraify(mlp, rank=2, key=key)
# Wrap in `quaxify` and call as normal.
some_output = quax.quaxify(mlp)(some_input)
```
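For a sense of how this fits into fine-tuning, here is a minimal sketch (not taken from the library itself; the model sizes, data and loss are made up for illustration) of computing gradients through a LoRA'd MLP:

```
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import quax
import quax.examples.lora as lora

model_key, lora_key, data_key = jr.split(jr.PRNGKey(0), 3)
mlp = eqx.nn.MLP(in_size=4, out_size=4, width_size=16, depth=2, key=model_key)
mlp = lora.loraify(mlp, rank=2, key=lora_key)

@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    # `quaxify` traces the forward pass, dispatching matmuls against
    # `LoraArray`s to their low-rank rules.
    pred = quax.quaxify(model)(x)
    return jnp.sum((pred - y) ** 2)

x = jr.normal(data_key, (4,))
y = jnp.zeros(4)
loss, grads = loss_fn(mlp, x, y)
# With the default `stop_gradient=True`, the original weights are frozen:
# only the low-rank adaptation matrices receive nonzero gradients.
```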

#### `quax.examples.lora.LoraArray (ArrayValue)`¤

Replaces a matrix `w in R^{n x m}` with `w + a @ b`, where `a in R^{n x k}` and `b in R^{k x m}`.

Typically `k` is much smaller than `n` or `m`, and so `w + a @ b` is described as a "low rank adaptation" of `w`. The value of `k` is the "rank" of the adaptation.
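For example, with `n = m = 1024` and `k = 4`, the matrix `w` has 1,048,576 entries, whilst `a` and `b` together contribute only 8,192 trainable parameters.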

Note that this does not materialise the sum `w + a @ b` into a single matrix, but instead stores it as three separate `w`, `a`, `b` matrices. This is because the typical use-case for LoRA is to update just the `a` and `b` matrices when fine-tuning a neural network.

This implementation makes use of Quax's multiple-dispatch capabilities to calculate matrix-vector products `(w + a @ b) @ x` via `w @ x + a @ (b @ x)`, which turns out to be computationally cheaper.
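As a quick sanity check of that identity (a standalone sketch, not part of the library), the two expressions agree up to floating-point error, whilst the factored form never forms an `(n, m)` update matrix:

```
import jax.numpy as jnp
import jax.random as jr

n, m, k = 512, 256, 4
kw, ka, kb, kx = jr.split(jr.PRNGKey(0), 4)
w = jr.normal(kw, (n, m))
a = jr.normal(ka, (n, k))
b = jr.normal(kb, (k, m))
x = jr.normal(kx, (m,))

dense = (w + a @ b) @ x         # materialises an (n, m) update: O(n*k*m) extra work
factored = w @ x + a @ (b @ x)  # two small matvecs instead: O(k*m + n*k) extra work
assert jnp.allclose(dense, factored, atol=1e-3)
```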

##### `__init__(self, weight: Array, *, rank: int, scale: float = 0.01, allow_materialise: bool = False, stop_gradient: bool = True, key: PRNGKeyArray)`¤

**Arguments:**

- `weight`: the original weight to wrap.
- `rank`: the rank of the low-rank adaptation.
- `scale`: `a` will be initialised at `Normal(0, scale^2)`. (`b` is initialised at zero.)
- `allow_materialise`: if Quax encounters an operation for which there is no specific override for `LoraArray`s, should it (a) throw an error (`allow_materialise=False`, the default), or (b) silently convert the `LoraArray` back into a JAX array, by explicitly calculating `w + a @ b` (`allow_materialise=True`)?
- `stop_gradient`: whether to automatically stop the gradient (prevent training) of the original weight matrix `weight`.
- `key`: used to provide randomness for initialising `a`.

## Example¤

Here's a copy of the LoRA example from the README again:

```
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora
#
# Start off with any JAX program: here, the forward pass through a linear layer.
#
key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))
def run(model, x):
    return model(x)
run(linear, vector) # can call this as normal
#
# Now let's Lora-ify it.
#
# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)
```