# Advanced tutorial - rules that redispatch back to Quax

In our previous two examples, all of our registered rules had our custom type interacting either with itself, or with `ArrayLike`s.

We can also arrange to have them interact with other Quax types, including ones that are authored by someone else, and which we don't know anything about! The key trick is to implement the part of the rule that we care about -- and then redispatch back to Quax to handle the other type(s).

```
from typing import Union

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
from jaxtyping import (  # https://github.com/patrick-kidger/jaxtyping
    Array,
    ArrayLike,
    Int,
    Shaped,
)

import quax
```

## Definitions

```
# Here's a rank-1 LoRA. This is basically a simple version of
# `quax.examples.lora.LoraArray`.
class LoraArray(quax.ArrayValue):
    w: Shaped[Array, "dim1 dim2"]
    a: Shaped[Array, " dim1"]
    b: Shaped[Array, " dim2"]

    def aval(self):
        shape = jnp.shape(self.w)
        dtype = jnp.result_type(self.w)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise `LoraArray`")


def _lora_matmul(
    w: Shaped[Array, "dim1 dim2"],
    a: Shaped[Array, " dim1"],
    b: Shaped[Array, " dim2"],
    y: Shaped[Array, " dim2"],
) -> Shaped[Array, " dim1"]:
    return w @ y + a * jnp.dot(b, y)


@quax.register(jax.lax.dot_general_p)
def _(
    x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params
):
    ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
    if jnp.ndim(x) != 2 or jnp.ndim(y) != 1:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )
    if (
        lhs_batch == ()
        and rhs_batch == ()
        and lhs_contract == (1,)
        and rhs_contract == (0,)
    ):
        # redispatch based on the type of `y`!
        return quax.quaxify(_lora_matmul)(x.w, x.a, x.b, y)
    else:
        raise NotImplementedError(
            f"Have not implemented dot_general for {dimension_numbers}."
        )
```

Notice how we haven't just allowed `y: ArrayLike`, but we have also allowed other Quax types as well! We've then redispatched based on the type of `y`.

So first of all, let's check that the usual `ArrayLike` argument still works as before:

```
matmul = lambda a, b: a @ b
w = jnp.ones((3, 4))
a = jnp.ones(3)
b = jnp.ones(4)
lora_array = LoraArray(w, a, b)
y = jnp.arange(4.0)
quax.quaxify(matmul)(lora_array, y)
```
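As a sanity check on the arithmetic, here is a standalone sketch using only `jax.numpy` (no Quax involved): the rank-1 LoRA product `w @ y + a * (b · y)` that our rule computes should agree with materialising the full matrix `w + a bᵀ` and multiplying.

```python
import jax.numpy as jnp

w = jnp.ones((3, 4))
a = jnp.ones(3)
b = jnp.ones(4)
y = jnp.arange(4.0)

# What the registered rule computes, without ever forming `w + outer(a, b)`:
lora = w @ y + a * jnp.dot(b, y)

# The materialised equivalent, for comparison only:
dense = (w + jnp.outer(a, b)) @ y

print(lora)  # [12. 12. 12.]
assert jnp.allclose(lora, dense)
```

This is exactly why `LoraArray.materialise` can afford to refuse: every operation we care about can be expressed in terms of the factors `w`, `a`, and `b` directly.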

## Redispatching

And now, let's check that we really can redispatch against another custom type. We're going to do

```
quax.quaxify(matmul)(LoraArray(...), SomeKindOfSparseVector(...))
```

So let's go ahead and do that quickly! Pretend that `LoraArray` and `SomeKindOfSparseVector` are implemented by two different people in two different codebases.

But because the `LoraArray` implementation redispatches, things "just work" without any need for special-casing compatibility between the two types.

```
class SomeKindOfSparseVector(quax.ArrayValue):
    """This represents a sparse vector with a single non-zero value."""

    index: Int[ArrayLike, ""]
    value: Shaped[ArrayLike, ""]
    length: int = eqx.field(static=True)

    def aval(self):
        shape = (self.length,)
        dtype = jnp.result_type(self.value)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise `SomeKindOfSparseVector`")


@quax.register(jax.lax.dot_general_p)
def _(x: Array, y: SomeKindOfSparseVector, *, dimension_numbers, **params):
    if jnp.ndim(x) == 1:
        (length,) = x.shape
        if length != y.length:
            raise ValueError("Mismatched vector shapes")
        return x[y.index] * y.value
    elif jnp.ndim(x) == 2:
        rows, cols = x.shape
        if cols != y.length:
            raise ValueError("Mismatched matrix and vector shapes")
        return x[:, y.index] * y.value
    else:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )


sparse_vector = SomeKindOfSparseVector(index=2, value=5, length=4)
quax.quaxify(matmul)(lora_array, sparse_vector)
```
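To see what that redispatch does numerically, here is another plain-`jax.numpy` sketch, with the sparse vector's single non-zero entry written out by hand. Inside `_lora_matmul`, both `w @ y` and `jnp.dot(b, y)` hit the `Array @ SomeKindOfSparseVector` rule, which picks out column (respectively element) `index` and scales by `value`:

```python
import jax.numpy as jnp

w = jnp.ones((3, 4))
a = jnp.ones(3)
b = jnp.ones(4)
index, value = 2, 5.0  # the sparse vector's single non-zero entry

# `w @ sparse` dispatches to the sparse rule: w[:, index] * value.
# `jnp.dot(b, sparse)` likewise gives b[index] * value.
result = w[:, index] * value + a * (b[index] * value)

# Compare against densifying both the LoRA matrix and the sparse vector:
dense_vec = jnp.zeros(4).at[index].set(value)
dense = (w + jnp.outer(a, b)) @ dense_vec

print(result)  # [10. 10. 10.]
assert jnp.allclose(result, dense)
```

Neither the LoRA matrix nor the sparse vector was ever materialised, yet the composed result matches the dense computation.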

The key takeaway here is that if you want to handle arbitrary Quax types:

- make the type annotation `Union[ArrayLike, quax.ArrayValue]`,
- and redispatch with a nested `quax.quaxify` call!

## Ambiguous lookup errors

When playing in these advanced waters, there is one possible failure mode to be aware of. Suppose the registration rule for `SomeKindOfSparseVector` looked like this instead:

```
@quax.register(jax.lax.dot_general_p)
def _(x: Union[ArrayLike, quax.ArrayValue], y: SomeKindOfSparseVector, *, dimension_numbers, **params):
```

That is, with the `x` annotation broadened from `Array` to `Union[ArrayLike, quax.ArrayValue]`.

Then, how should the top-level `quax.quaxify(matmul)(lora_array, sparse_vector)` work? Should the matmul bind against the above rule (which is valid, as `LoraArray` is a subclass of `quax.ArrayValue` and `SomeKindOfSparseVector` matches exactly), or should it bind against the

```
@quax.register(jax.lax.dot_general_p)
def _(x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params):
```

rule from earlier (in which `LoraArray` matches exactly, and `SomeKindOfSparseVector` is a subclass of `quax.ArrayValue`)?

In this case, due to the ambiguity, an `AmbiguousLookupError` will be raised! Let's experiment by doing that now, overwriting our previously-registered rule:

```
@quax.register(jax.lax.dot_general_p)
def _(
    x: Union[ArrayLike, quax.ArrayValue],
    y: SomeKindOfSparseVector,
    *,
    dimension_numbers,
    **params,
):
    if jnp.ndim(x) == 1:
        (length,) = x.shape
        if length != y.length:
            raise ValueError("Mismatched vector shapes")
        return x[y.index] * y.value
    elif jnp.ndim(x) == 2:
        rows, cols = x.shape
        if cols != y.length:
            raise ValueError("Mismatched matrix and vector shapes")
        return x[:, y.index] * y.value
    else:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )


try:
    quax.quaxify(matmul)(lora_array, sparse_vector)
except Exception as e:
    print(repr(e))
```
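The ambiguity itself can be mimicked without Quax at all. Here is a toy two-argument dispatcher (all class and rule names here are hypothetical stand-ins, not the real dispatch machinery): both rules match the pair `(LoraArray, SparseVector)`, and neither rule's signature is strictly more specific than the other's, so there is no unique winner to pick.

```python
class ArrayValue: ...           # toy stand-in for `quax.ArrayValue`
class LoraArray(ArrayValue): ...
class SparseVector(ArrayValue): ...

# Two registered "rules", each keyed by a pair of accepted types:
rules = {
    "lora rule": (LoraArray, ArrayValue),
    "sparse rule": (ArrayValue, SparseVector),
}

def lookup(x, y):
    matches = [
        name
        for name, (tx, ty) in rules.items()
        if isinstance(x, tx) and isinstance(y, ty)
    ]
    if len(matches) > 1:
        # Neither matching signature is a subtype of the other, so there is
        # no most-specific rule to prefer.
        raise LookupError(f"Ambiguous lookup: {matches}")
    [name] = matches
    return name

try:
    lookup(LoraArray(), SparseVector())
except LookupError as e:
    print(e)  # Ambiguous lookup: ['lora rule', 'sparse rule']
```

This is the same situation the real multiple-dispatch lookup is in, hence the `AmbiguousLookupError`.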

We have two ways of fixing this.

### Solution 1: nested quaxifies

The first (and preferred) way is to use a nested `quax.quaxify`.

Under the hood, `quax.quaxify(fn, filter_spec)(*args, **kwargs)` will run `dynamic, static = eqx.partition((fn, args, kwargs), filter_spec)`, and then it will only quaxify those arguments in `dynamic`, whilst those in `static` will be left untouched.
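The partition step can be sketched in plain Python. This is only a toy model of splitting a flat list of leaves on a predicate; the real `eqx.partition` operates on arbitrary PyTrees, so treat it purely as an illustration of the semantics (`Lora` and `Sparse` are hypothetical stand-in classes):

```python
def toy_partition(leaves, filter_spec):
    """Split leaves into (dynamic, static) by a predicate, keeping
    placeholders so the two halves can be recombined by position."""
    dynamic = [leaf if filter_spec(leaf) else None for leaf in leaves]
    static = [None if filter_spec(leaf) else leaf for leaf in leaves]
    return dynamic, static

class Lora: ...    # hypothetical stand-ins for the two custom array types
class Sparse: ...

lora, sparse = Lora(), Sparse()
leaves = [lora, sparse]

# Inner quaxify: only the LoRA leaf is treated as dynamic (i.e. quaxified).
dynamic, static = toy_partition(leaves, lambda x: isinstance(x, Lora))
assert dynamic == [lora, None]
assert static == [None, sparse]
```

The outer quaxify then does the same thing with a predicate selecting the sparse leaf, which is what makes the two-layer wrapping unambiguous.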

So in this case, we started off with a `matmul` that does `array @ vector`. The inner `quaxify` turns that into a function that's defined to do `lora_array @ vector`. The outer `quaxify` then turns that into a function that's defined to do `lora_array @ sparse_vector`. This means that we now have an unambiguous lookup order: by construction (from our inner `quaxify`) we've specified that we want to use the

```
@quax.register(jax.lax.dot_general_p)
def _(x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params):
```

rule, like so:

```
is_lora_array = lambda x: isinstance(x, LoraArray)
is_sparse = lambda x: isinstance(x, SomeKindOfSparseVector)
quax.quaxify(quax.quaxify(matmul, filter_spec=is_lora_array), filter_spec=is_sparse)(
    lora_array, sparse_vector
)
```

**Note**

Incidentally, from the behaviour of `eqx.partition`, we could also have passed `(False, (True, False), False)` for the inner `filter_spec`, and `(False, (False, True), False)` for the outer `filter_spec`: this will explicitly pick out the LoRA and sparse objects by position, rather than by type.

**Note**

The order of these two quaxifies is important. If we'd done it the other way around, then we would have hit the `ArrayValue @ SomeKindOfSparseVector` combination first. However, that involves indexing (`x[:, y.index]`), and we (a) haven't provided an override for that operation for `LoraArray`, and (b) have disallowed materialising the `LoraArray`. So if we'd switched the quaxifies, we would have gotten a trace-time error.

### Solution 2: override the combination

Okay, on to the second (less preferred) way: we can explicitly define an override rule for the combination:

```
@quax.register(jax.lax.dot_general_p)
def _(
    x: LoraArray, y: SomeKindOfSparseVector, *, dimension_numbers, **params
): ...  # some implementation here
```

However, this is discouraged, as (a) it involves mutating global state (the multiple-dispatch lookup table), which could potentially have effects in other parts of your codebase or in libraries you use; and (b) it means that you have to figure out the implementation for this combination yourself.