Advanced tutorial - rules that redispatch back to Quax
In our previous two examples, all of our registered rules had our custom type interacting either with itself or with `ArrayLike`s.
We can also arrange for them to interact with other Quax types, including ones written by someone else that we know nothing about! The key trick is to implement only the part of the rule that we care about, and then redispatch back to Quax to handle the other type(s).
```python
from typing import Union

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
from jaxtyping import (  # https://github.com/patrick-kidger/jaxtyping
    Array,
    ArrayLike,
    Int,
    Shaped,
)

import quax
```
Definitions
```python
# Here's a rank-1 LoRA. This is basically a simple version of
# `quax.examples.lora.LoraArray`.
class LoraArray(quax.ArrayValue):
    w: Shaped[Array, "dim1 dim2"]
    a: Shaped[Array, " dim1"]
    b: Shaped[Array, " dim2"]

    def aval(self):
        shape = jnp.shape(self.w)
        dtype = jnp.result_type(self.w)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise `LoraArray`")


def _lora_matmul(
    w: Shaped[Array, "dim1 dim2"],
    a: Shaped[Array, " dim1"],
    b: Shaped[Array, " dim2"],
    y: Shaped[Array, " dim2"],
) -> Shaped[Array, " dim1"]:
    # Computes (w + outer(a, b)) @ y without ever forming the rank-1 update.
    return w @ y + a * jnp.dot(b, y)


@quax.register(jax.lax.dot_general_p)
def _(
    x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params
):
    ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
    # Only matrix-vector products are supported: reject anything else.
    if jnp.ndim(x) != 2 or jnp.ndim(y) != 1:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )
    if (
        lhs_batch == ()
        and rhs_batch == ()
        and lhs_contract == (1,)
        and rhs_contract == (0,)
    ):
        # redispatch based on the type of `y`!
        return quax.quaxify(_lora_matmul)(x.w, x.a, x.b, y)
    else:
        raise NotImplementedError(
            f"Have not implemented dot_general for {dimension_numbers}."
        )
```
Notice that we haven't just allowed `y: ArrayLike`; we've also allowed other Quax types! We then redispatch based on the type of `y`.
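For a plain matrix-vector product, the rule above only accepts `dimension_numbers == (((1,), (0,)), ((), ()))`: contract axis 1 of the matrix against axis 0 of the vector, with no batch axes. As an illustration in NumPy (not JAX or Quax), `numpy.tensordot` with the same pair of contracting axes reproduces `w @ y`. `reference_dot_general` is a hypothetical helper for this sketch and deliberately ignores batch dimensions.

```python
import numpy as np


def reference_dot_general(x, y, dimension_numbers):
    """Hypothetical NumPy reference for dot_general, without batch dimensions."""
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    if lhs_batch or rhs_batch:
        raise NotImplementedError("batch dimensions are not handled in this sketch")
    # Sum over the paired contracting axes; the free axes of x come first, then y's.
    return np.tensordot(x, y, axes=(lhs_contract, rhs_contract))


matvec_dims = (((1,), (0,)), ((), ()))  # the case accepted by the LoRA rule
w = np.arange(12.0).reshape(3, 4)
y = np.arange(4.0)
out = reference_dot_general(w, y, matvec_dims)
print(out)  # [14. 38. 62.], identical to w @ y
```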
So first of all, let's check that the usual `ArrayLike` argument still works as before:
```python
matmul = lambda a, b: a @ b
w = jnp.ones((3, 4))
a = jnp.ones(3)
b = jnp.ones(4)
lora_array = LoraArray(w, a, b)
y = jnp.arange(4.0)
quax.quaxify(matmul)(lora_array, y)
```
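Since the rule computes `w @ y + a * dot(b, y)`, a `LoraArray` behaves under matrix-vector multiplication exactly like the dense matrix `w + outer(a, b)`. As a NumPy sanity check (independent of Quax) using the same values as above, this is the result that `quax.quaxify(matmul)(lora_array, y)` should reproduce:

```python
import numpy as np

w = np.ones((3, 4))
a = np.ones(3)
b = np.ones(4)
y = np.arange(4.0)

# The rank-1 formula used by `_lora_matmul`: never forms the (3, 4) update.
low_rank = w @ y + a * np.dot(b, y)
# The same product with the rank-1 update materialised explicitly.
dense = (w + np.outer(a, b)) @ y

print(low_rank)  # [12. 12. 12.]
```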
Redispatching
And now, let's check that we really can redispatch against another custom type. We're going to do `quax.quaxify(matmul)(LoraArray(...), SomeKindOfSparseVector(...))`.
So let's go ahead and do that now. Pretend that `LoraArray` and `SomeKindOfSparseVector` are implemented by two different people in two different codebases. Because the `LoraArray` implementation redispatches, things "just work", with no need to special-case compatibility between the two types.
```python
class SomeKindOfSparseVector(quax.ArrayValue):
    """This represents a sparse vector with a single non-zero value."""

    index: Int[ArrayLike, ""]
    value: Shaped[ArrayLike, ""]
    length: int = eqx.field(static=True)

    def aval(self):
        shape = (self.length,)
        dtype = jnp.result_type(self.value)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise `SomeKindOfSparseVector`")


@quax.register(jax.lax.dot_general_p)
def _(x: Array, y: SomeKindOfSparseVector, *, dimension_numbers, **params):
    # Multiplying by a vector whose only non-zero entry is `y.value` at
    # `y.index` just selects that entry (or column) of `x` and scales it.
    if jnp.ndim(x) == 1:
        (length,) = x.shape
        if length != y.length:
            raise ValueError("Mismatched vector shapes")
        return x[y.index] * y.value
    elif jnp.ndim(x) == 2:
        rows, cols = x.shape
        if cols != y.length:
            raise ValueError("Mismatched matrix and vector shapes")
        return x[:, y.index] * y.value
    else:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )


sparse_vector = SomeKindOfSparseVector(index=2, value=5, length=4)
quax.quaxify(matmul)(lora_array, sparse_vector)
```
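The sparse rule relies on the fact that multiplying by a vector with a single non-zero entry `value` at position `index` just selects one column (or one entry) and scales it. A NumPy sketch checking this against an explicitly materialised vector; `to_dense` is a hypothetical helper used only for the check:

```python
import numpy as np


def to_dense(index, value, length):
    """Hypothetical helper: materialise a vector with one non-zero entry."""
    out = np.zeros(length)
    out[index] = value
    return out


index, value, length = 2, 5.0, 4
y_dense = to_dense(index, value, length)

x_mat = np.arange(12.0).reshape(3, 4)
x_vec = np.arange(4.0)

# Matrix case: pick column `index`, scale by `value`.
mat_result = x_mat[:, index] * value
# Vector case: pick entry `index`, scale by `value`.
vec_result = x_vec[index] * value

# With the LoRA values used above, lora_array @ sparse_vector should match
# the dense product (w + outer(a, b)) @ y_dense.
w, a, b = np.ones((3, 4)), np.ones(3), np.ones(4)
lora_sparse = (w + np.outer(a, b)) @ y_dense
print(lora_sparse)  # [10. 10. 10.]
```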
The key takeaway here is that if you want to handle arbitrary Quax types:
- make the type annotation be `Union[ArrayLike, quax.ArrayValue]`,
- and redispatch with a nested `quax.quaxify` call!
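To see why redispatching composes, here is a deliberately tiny, pure-Python model of the idea. It is not how Quax is implemented (Quax works through JAX tracing and a multiple-dispatch library); `RULES`, `register`, `matmul`, `Lora` and `OneHot` are all hypothetical names, and the lookup order is a toy policy. The LoRA rule only knows how to split itself into pieces, then calls the generic `matmul` again, so a rule written independently for `OneHot` handles the rest.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Lora:
    """Hypothetical stand-in for LoraArray: represents w + outer(a, b)."""
    w: np.ndarray
    a: np.ndarray
    b: np.ndarray


@dataclass
class OneHot:
    """Hypothetical stand-in for SomeKindOfSparseVector."""
    index: int
    value: float
    length: int


RULES = {}  # (lhs type, rhs type) -> rule; `object` acts as a wildcard


def register(lhs_type, rhs_type):
    def decorator(rule):
        RULES[(lhs_type, rhs_type)] = rule
        return rule
    return decorator


def matmul(x, y):
    # Toy lookup policy: exact pair first, then a wildcard on either side.
    for key in [(type(x), type(y)), (type(x), object), (object, type(y))]:
        if key in RULES:
            return RULES[key](x, y)
    raise NotImplementedError(f"no rule for {type(x).__name__} @ {type(y).__name__}")


@register(np.ndarray, np.ndarray)
def _dense(x, y):
    return x @ y


@register(np.ndarray, OneHot)
def _dense_onehot(x, y):
    # Select entry/column `index` and scale it; works for vectors and matrices.
    return x[..., y.index] * y.value


@register(Lora, object)
def _lora(x, y):
    # Only the LoRA structure is handled here; `matmul` redispatches on `y`.
    return matmul(x.w, y) + x.a * matmul(x.b, y)


lora = Lora(np.ones((3, 4)), np.ones(3), np.ones(4))
print(matmul(lora, np.arange(4.0)))  # [12. 12. 12.]
print(matmul(lora, OneHot(index=2, value=5.0, length=4)))  # [10. 10. 10.]
```

Note that `_lora` never mentions `OneHot`: the two "codebases" only meet through the shared `matmul` entry point.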
Ambiguous lookup errors
When playing in these advanced waters, there is one possible failure mode to be aware of. Suppose the registration rule for `SomeKindOfSparseVector` looked like this instead:
```python
@quax.register(jax.lax.dot_general_p)
def _(x: Union[ArrayLike, quax.ArrayValue], y: SomeKindOfSparseVector, *, dimension_numbers, **params):
    ...
```
that is, with `x` now also allowed to be any `quax.ArrayValue`.
Then how should the top-level `quax.quaxify(matmul)(lora_array, sparse_vector)` work? Should the matmul bind against the above rule (which is valid, as `LoraArray` is a subclass of `quax.ArrayValue` and `SomeKindOfSparseVector` matches exactly), or should it bind against the
```python
@quax.register(jax.lax.dot_general_p)
def _(x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params):
    ...
```
rule (for which `LoraArray` matches exactly, and `SomeKindOfSparseVector` is a subclass of `quax.ArrayValue`)?
Because of this ambiguity, an `AmbiguousLookupError` will be raised! Let's experiment by registering that rule now:
```python
@quax.register(jax.lax.dot_general_p)
def _(
    x: Union[ArrayLike, quax.ArrayValue],
    y: SomeKindOfSparseVector,
    *,
    dimension_numbers,
    **params,
):
    if jnp.ndim(x) == 1:
        (length,) = x.shape
        if length != y.length:
            raise ValueError("Mismatched vector shapes")
        return x[y.index] * y.value
    elif jnp.ndim(x) == 2:
        rows, cols = x.shape
        if cols != y.length:
            raise ValueError("Mismatched matrix and vector shapes")
        return x[:, y.index] * y.value
    else:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )


try:
    quax.quaxify(matmul)(lora_array, sparse_vector)
except Exception as e:
    print(repr(e))
```
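The ambiguity can be reproduced with a small pure-Python model of multiple-dispatch resolution: among the signatures that match the argument types, keep only those that no other match beats by being strictly more specific (every parameter a subclass); if more than one survives, the call is ambiguous. This is a simplified sketch of the general rule, not Quax's actual dispatch code. The classes are hypothetical stand-ins, with `ArrayValue` playing the role of the broad `Union[ArrayLike, quax.ArrayValue]` annotation.

```python
class ArrayValue:
    """Stand-in for the broad `Union[ArrayLike, quax.ArrayValue]` annotation."""


class LoraArray(ArrayValue):
    pass


class SparseVector(ArrayValue):
    pass


class AmbiguousLookup(Exception):
    pass


def at_least_as_specific(sig1, sig2):
    # sig1 <= sig2 when every parameter type of sig1 is a subclass of sig2's.
    return all(issubclass(t1, t2) for t1, t2 in zip(sig1, sig2))


def resolve(signatures, arg_types):
    matching = [s for s in signatures if at_least_as_specific(arg_types, s)]
    if not matching:
        raise LookupError(f"no rule matches {arg_types}")
    # Keep only the signatures that no other match is strictly more specific than.
    best = [
        s
        for s in matching
        if not any(
            at_least_as_specific(t, s) and not at_least_as_specific(s, t)
            for t in matching
        )
    ]
    if len(best) > 1:
        raise AmbiguousLookup(best)
    return best[0]


def is_ambiguous(signatures, arg_types):
    try:
        resolve(signatures, arg_types)
    except AmbiguousLookup:
        return True
    return False


rules = [
    (LoraArray, ArrayValue),  # LoRA on the left, anything on the right
    (ArrayValue, SparseVector),  # the broadened sparse rule
]
print(is_ambiguous(rules, (LoraArray, SparseVector)))  # True
print(is_ambiguous(rules, (LoraArray, LoraArray)))  # False: only the LoRA rule matches
# A rule for the exact combination is strictly more specific than both.
print(resolve(rules + [(LoraArray, SparseVector)], (LoraArray, SparseVector)))
```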
We have two ways of fixing this.
Solution 1: nested quaxifies
The first (and preferred) way is to use a nested `quax.quaxify`.
Under the hood, `quax.quaxify(fn, filter_spec)(*args, **kwargs)` will run `dynamic, static = eqx.partition((fn, args, kwargs), filter_spec)`, and will then only quaxify the arguments in `dynamic`, whilst those in `static` are left untouched.
So in this case, we started off with a `matmul` that does `array @ vector`. The inner `quaxify` turns that into a function that's defined to do `lora_array @ vector`. The outer `quaxify` then turns that into a function that's defined to do `lora_array @ sparse_vector`. This means that we now have an unambiguous lookup order: by construction (from our inner `quaxify`) we've specified that we want to use the
```python
@quax.register(jax.lax.dot_general_p)
def _(x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params):
    ...
```
rule.
```python
is_lora_array = lambda x: isinstance(x, LoraArray)
is_sparse = lambda x: isinstance(x, SomeKindOfSparseVector)
quax.quaxify(quax.quaxify(matmul, filter_spec=is_lora_array), filter_spec=is_sparse)(
    lora_array, sparse_vector
)
```
Note
Incidentally, given the behaviour of `eqx.partition`, we could also have passed `(False, (True, False), False)` for the inner `filter_spec` and `(False, (False, True), False)` for the outer `filter_spec`: these explicitly pick out the LoRA and sparse objects by position, rather than by type.
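The filtering step can be modelled in a few lines of pure Python. The sketch below is a simplified, hypothetical stand-in for `eqx.partition`: the real function works on arbitrary PyTrees, whereas this one only recurses into tuples and treats everything else (including the `{}` standing in for `kwargs`) as a leaf. `filter_spec` may be a callable applied to each leaf, a `bool`, or a tuple mirroring the structure. `Lora` and `Sparse` are placeholder classes. It shows that the type-based and position-based filter specs select the same argument.

```python
def partition(tree, filter_spec):
    """Toy stand-in for `eqx.partition` on nested tuples.

    Returns (dynamic, static): selected leaves go to `dynamic`, the rest to
    `static`, with `None` filling the corresponding slot on the other side.
    """
    if isinstance(filter_spec, tuple):
        # Structural spec: walk the tree and the spec in lockstep.
        parts = [partition(sub, spec) for sub, spec in zip(tree, filter_spec)]
        return tuple(d for d, _ in parts), tuple(s for _, s in parts)
    if isinstance(tree, tuple):
        # A bool or callable spec applies to every leaf underneath.
        return partition(tree, tuple(filter_spec for _ in tree))
    keep = filter_spec(tree) if callable(filter_spec) else filter_spec
    return (tree, None) if keep else (None, tree)


class Lora:  # hypothetical placeholder for a LoraArray instance
    pass


class Sparse:  # hypothetical placeholder for a SomeKindOfSparseVector instance
    pass


def matmul(a, b):
    return a @ b


lora, sparse = Lora(), Sparse()
tree = (matmul, (lora, sparse), {})  # mirrors (fn, args, kwargs)

inner_by_type = partition(tree, lambda x: isinstance(x, Lora))
inner_by_position = partition(tree, (False, (True, False), False))
outer_by_type = partition(tree, lambda x: isinstance(x, Sparse))
outer_by_position = partition(tree, (False, (False, True), False))

dynamic, static = inner_by_type
print(dynamic)  # only the Lora object is kept; every other slot is None
```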
Note
The order of these two quaxifies is important. If we'd done it the other way around, then we would have hit the `ArrayValue @ SomeKindOfSparseVector` combination first. However, that involves indexing (`x[:, y.index]`), and we (a) haven't provided an override for that operation for `LoraArray`, and (b) have disallowed materialising the `LoraArray`. So if we'd switched the quaxifies, we would have gotten a trace-time error.
Solution 2: override the combination
Okay, on to the second (less preferred) way: we can explicitly define an override rule for the combination:
```python
@quax.register(jax.lax.dot_general_p)
def _(
    x: LoraArray, y: SomeKindOfSparseVector, *, dimension_numbers, **params
): ...  # some implementation here
```
However, this is discouraged because (a) it involves mutating global state (the multiple-dispatch lookup table), which could affect other parts of your codebase or your libraries, and (b) it means that you have to work out the implementation for this combination yourself.
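To illustrate point (b): a hand-written rule for the `LoraArray` and `SomeKindOfSparseVector` combination must re-derive what the redispatching version gets for free. For a matrix-vector product, `(w + outer(a, b)) @ y` with a single non-zero `value` at `index` reduces to `w[:, index] * value + a * (b[index] * value)`. A NumPy sketch of that formula (the function name is hypothetical), checked against the dense product:

```python
import numpy as np


def lora_times_one_hot(w, a, b, index, value):
    """Hypothetical fused rule: (w + outer(a, b)) @ y for a one-non-zero y."""
    # Column `index` of w plus the rank-1 term's column, both scaled by `value`.
    return w[:, index] * value + a * (b[index] * value)


rng = np.random.default_rng(0)
w = rng.normal(size=(3, 4))
a = rng.normal(size=3)
b = rng.normal(size=4)
index, value = 2, 5.0

y_dense = np.zeros(4)
y_dense[index] = value

fused = lora_times_one_hot(w, a, b, index, value)
dense = (w + np.outer(a, b)) @ y_dense
print(np.allclose(fused, dense))  # True
```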