
Advanced tutorial - rules that redispatch back to Quax

In our previous two examples, all of our registered rules had our custom type interacting either with itself or with ArrayLikes.

We can also arrange to have them interact with other Quax types, including ones that are authored by someone else, and which we don't know anything about! The key trick is to implement the part of the rule that we care about, and then redispatch back to Quax to handle the other type(s).

from typing import Union

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
from jaxtyping import (  # https://github.com/patrick-kidger/jaxtyping
    Array,
    ArrayLike,
    Int,
    Shaped,
)

import quax

Definitions

# Here's a rank-1 LoRA. This is basically a simple version of
# `quax.examples.lora.LoraArray`.


class LoraArray(quax.ArrayValue):
    w: Shaped[Array, "dim1 dim2"]
    a: Shaped[Array, " dim1"]
    b: Shaped[Array, " dim2"]

    def aval(self):
        shape = jnp.shape(self.w)
        dtype = jnp.result_type(self.w)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise `LoraArray`")


def _lora_matmul(
    w: Shaped[Array, "dim1 dim2"],
    a: Shaped[Array, " dim1"],
    b: Shaped[Array, " dim2"],
    y: Shaped[Array, " dim2"],
) -> Shaped[Array, " dim1"]:
    return w @ y + a * jnp.dot(b, y)


@quax.register(jax.lax.dot_general_p)
def _(
    x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params
):
    ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
    if jnp.ndim(x) != 2 or jnp.ndim(y) != 1:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )
    if (
        lhs_batch == ()
        and rhs_batch == ()
        and lhs_contract == (1,)
        and rhs_contract == (0,)
    ):
        # redispatch based on the type of `y`!
        return quax.quaxify(_lora_matmul)(x.w, x.a, x.b, y)
    else:
        raise NotImplementedError(
            f"Have not implemented dot_general for {dimension_numbers}."
        )

Notice how we haven't just allowed y: ArrayLike, but we have also allowed other Quax types as well! We've then redispatched based on the type of y.

So first of all, let's check that the usual ArrayLike argument still works as before:

matmul = lambda a, b: a @ b
w = jnp.ones((3, 4))
a = jnp.ones(3)
b = jnp.ones(4)
lora_array = LoraArray(w, a, b)
y = jnp.arange(4.0)
quax.quaxify(matmul)(lora_array, y)
Array([12., 12., 12.], dtype=float32)
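As a sanity check on the arithmetic, here is a sketch in plain NumPy (used as a stand-in for `jax.numpy`, with no Quax types involved). The rank-1 LoRA stands for the dense matrix `w + outer(a, b)`, and `_lora_matmul` computes its product with `y` without ever forming that matrix: each row gets `w @ y = 0 + 1 + 2 + 3 = 6` plus `a * (b · y) = 6`, hence 12.

```python
import numpy as np

w = np.ones((3, 4))
a = np.ones(3)
b = np.ones(4)
y = np.arange(4.0)

# Low-rank route: W @ y plus `a` scaled by the scalar b . y (what `_lora_matmul` does).
low_rank = w @ y + a * np.dot(b, y)

# Dense route: materialise W + a b^T explicitly (which `LoraArray` refuses to do).
dense = (w + np.outer(a, b)) @ y

print(low_rank)  # [12. 12. 12.]
assert np.allclose(low_rank, dense)
```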

Redispatching

And now, let's check that we really can redispatch against another custom type. We're going to do

quax.quaxify(matmul)(LoraArray(...), SomeKindOfSparseVector(...))

So let's go ahead and do that quickly! Pretend that LoraArray and SomeKindOfSparseVector are implemented by two different people in two different codebases.

Because the LoraArray implementation redispatches, things "just work" without any need for special-casing compatibility between the two types.

class SomeKindOfSparseVector(quax.ArrayValue):
    """This represents a sparse vector with a single non-zero value."""

    index: Int[ArrayLike, ""]
    value: Shaped[ArrayLike, ""]
    length: int = eqx.field(static=True)

    def aval(self):
        shape = (self.length,)
        dtype = jnp.result_type(self.value)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise `SomeKindOfSparseVector`")


@quax.register(jax.lax.dot_general_p)
def _(x: Array, y: SomeKindOfSparseVector, *, dimension_numbers, **params):
    if jnp.ndim(x) == 1:
        (length,) = x.shape
        if length != y.length:
            raise ValueError("Mismatched vector shapes")
        return x[y.index] * y.value
    elif jnp.ndim(x) == 2:
        rows, cols = x.shape
        if cols != y.length:
            raise ValueError("Mismatched matrix and vector shapes")
        return x[:, y.index] * y.value
    else:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )


sparse_vector = SomeKindOfSparseVector(index=2, value=5, length=4)
quax.quaxify(matmul)(lora_array, sparse_vector)
Array([10., 10., 10.], dtype=float32)
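Where does `[10., 10., 10.]` come from? The sketch below (plain NumPy, no Quax) densifies the sparse vector, which `SomeKindOfSparseVector` itself refuses to do, and checks that the shortcut in its `dot_general` rule, `x[:, index] * value`, agrees with an ordinary dense product. Inside `_lora_matmul` that rule fires twice: once for `w @ y` (the matrix case, giving 5 per row) and once for `jnp.dot(b, y)` (the vector case, giving 5), so each row is `5 + 1 * 5 = 10`. The `densify` helper is hypothetical and exists only for this check.

```python
import numpy as np


def densify(index, value, length):
    """Hypothetical helper: build the dense vector with a single non-zero entry."""
    out = np.zeros(length)
    out[index] = value
    return out


w = np.ones((3, 4))
a = np.ones(3)
b = np.ones(4)
index, value, length = 2, 5, 4
y_dense = densify(index, value, length)

# Shortcuts from the sparse rule: only entry/column `index` contributes.
w_y = w[:, index] * value    # matrix case: [5., 5., 5.]
b_y = b[index] * value       # vector case: 5.0
lora_result = w_y + a * b_y  # mirrors `_lora_matmul`

assert np.allclose(w_y, w @ y_dense)
assert np.isclose(b_y, np.dot(b, y_dense))
print(lora_result)  # [10. 10. 10.]
```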

The key takeaway here is that if you want to handle arbitrary Quax types:

- make the type annotation be Union[ArrayLike, quax.ArrayValue],
- and redispatch with a nested quax.quaxify call!

Ambiguous lookup errors

When playing in these advanced waters, there is one possible failure mode to be aware of. Suppose the registration rule for SomeKindOfSparseVector looked like this instead:

@quax.register(jax.lax.dot_general_p)
def _(x: Union[ArrayLike, quax.ArrayValue], y: SomeKindOfSparseVector, *, dimension_numbers, **params):
where the first argument can be a quax.ArrayValue.

Then, how should the top-level quax.quaxify(matmul)(lora_array, sparse_vector) work? Should the matmul bind against the above rule (which is valid as LoraArray is a subclass of quax.ArrayValue, and SomeKindOfSparseVector matches exactly), or should it bind against the

@quax.register(jax.lax.dot_general_p)
def _(x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params):
rule we defined earlier (which is valid as LoraArray matches exactly, and SomeKindOfSparseVector is a subclass of quax.ArrayValue)?

In this case, due to the ambiguity, an AmbiguousLookupError will be raised! Let's experiment by doing that now, registering this broader rule alongside the ones we already have:

@quax.register(jax.lax.dot_general_p)
def _(
    x: Union[ArrayLike, quax.ArrayValue],
    y: SomeKindOfSparseVector,
    *,
    dimension_numbers,
    **params,
):
    if jnp.ndim(x) == 1:
        (length,) = x.shape
        if length != y.length:
            raise ValueError("Mismatched vector shapes")
        return x[y.index] * y.value
    elif jnp.ndim(x) == 2:
        rows, cols = x.shape
        if cols != y.length:
            raise ValueError("Mismatched matrix and vector shapes")
        return x[:, y.index] * y.value
    else:
        raise NotImplementedError(
            "Have not implemented dot_general except for matrix-vector products"
        )


try:
    quax.quaxify(matmul)(lora_array, sparse_vector)
except Exception as e:
    print(repr(e))
AmbiguousLookupError('For function `dot_general_dispatcher`, `(LoraArray(w=f32[3,4], a=f32[3], b=f32[4]), SomeKindOfSparseVector(index=2, value=5, length=4))` is ambiguous among the following:\n  Signature(__main__.LoraArray, typing.Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, quax._core.ArrayValue], implementation=<function _ at 0x7faa24241f80>) (precedence: 0)\n  Signature(typing.Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex, quax._core.ArrayValue], __main__.SomeKindOfSparseVector, implementation=<function _ at 0x7faa242e0360>) (precedence: 0)')
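The ambiguity comes from multiple dispatch itself, not from anything specific to Quax. The sketch below is a deliberately tiny, hypothetical dispatcher (it is not how Quax or plum are implemented): a signature applies if each argument is an instance of its annotation, and a rule wins only if it is at least as specific as every other applicable rule in every position. Neither of our two rules beats the other, so the lookup is ambiguous. The class names here are toy stand-ins.

```python
# Toy stand-ins for the real classes; only the subclass relationships matter.
class ArrayValue: ...
class LoraArray(ArrayValue): ...
class SparseVector(ArrayValue): ...


class AmbiguousLookup(Exception):
    pass


def at_least_as_specific(sig1, sig2):
    """True if every annotation in sig1 is a subclass of the one in sig2."""
    return all(issubclass(t1, t2) for t1, t2 in zip(sig1, sig2))


def resolve(rules, *args):
    # A signature applies if each argument is an instance of its annotation.
    applicable = [
        sig for sig in rules if all(isinstance(x, t) for x, t in zip(args, sig))
    ]
    if not applicable:
        raise LookupError("no matching rule")
    # The winner must be at least as specific as every other applicable rule.
    best = [
        s for s in applicable
        if all(at_least_as_specific(s, other) for other in applicable)
    ]
    if len(best) != 1:
        raise AmbiguousLookup(applicable)
    return rules[best[0]]


# Mirrors the two conflicting registrations from the text.
rules = {
    (LoraArray, ArrayValue): "lora rule",
    (ArrayValue, SparseVector): "sparse rule",
}

try:
    resolve(rules, LoraArray(), SparseVector())
    ambiguous = False
except AmbiguousLookup:
    ambiguous = True

print(ambiguous)  # True
```

With only one applicable rule, for example `resolve(rules, LoraArray(), ArrayValue())`, the lookup is unambiguous and returns `"lora rule"`.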

We have two ways of fixing this.

Solution 1: nested quaxifies

The first (and preferred) way is to use nested quax.quaxify calls.

Under the hood, quax.quaxify(fn, filter_spec)(*args, **kwargs) will run dynamic, static = eqx.partition((fn, args, kwargs), filter_spec), and then it will only quaxify those arguments in dynamic, whilst those in static will be left untouched.
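Here is a simplified, hypothetical sketch of that split, restricted to a flat tuple of leaves and a callable filter (the real `eqx.partition` works on arbitrary pytrees and also accepts pytree-shaped boolean filter specs). Leaves that pass the filter go into `dynamic`, everything else into `static`, with `None` marking the holes; `combine_flat` stitches them back together, in the spirit of `eqx.combine`. The classes are toy stand-ins.

```python
class LoraArray: ...     # toy stand-ins for the real classes
class SparseVector: ...


def partition_flat(leaves, filter_fn):
    """Hypothetical flat analogue of eqx.partition."""
    dynamic = tuple(x if filter_fn(x) else None for x in leaves)
    static = tuple(None if filter_fn(x) else x for x in leaves)
    return dynamic, static


def combine_flat(dynamic, static):
    """Hypothetical flat analogue of eqx.combine (assumes no leaf is None)."""
    return tuple(d if d is not None else s for d, s in zip(dynamic, static))


matmul = lambda a, b: a @ b
lora, sparse = LoraArray(), SparseVector()
# Flattened (fn, args, kwargs) with empty kwargs.
leaves = (matmul, lora, sparse)

is_lora_array = lambda x: isinstance(x, LoraArray)
dynamic, static = partition_flat(leaves, is_lora_array)
# Only the LoRA object lands in `dynamic`, so only it would be quaxified.
print(dynamic[1] is lora, static[2] is sparse)  # True True
```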

So in this case, we started off with a matmul that does array @ vector. The inner quaxify turns that into a function that's defined to do lora_array @ vector. The outer quaxify then turns that into a function that's defined to do lora_array @ sparse_vector. This means that we now have an unambiguous lookup order: by construction (from our inner quaxify) we've specified that we want to use the

@quax.register(jax.lax.dot_general_p)
def _(x: LoraArray, y: Union[ArrayLike, quax.ArrayValue], *, dimension_numbers, **params):
rule first.

is_lora_array = lambda x: isinstance(x, LoraArray)
is_sparse = lambda x: isinstance(x, SomeKindOfSparseVector)
quax.quaxify(quax.quaxify(matmul, filter_spec=is_lora_array), filter_spec=is_sparse)(
    lora_array, sparse_vector
)
Array([10., 10., 10.], dtype=float32)

Note

Incidentally, from the behaviour of eqx.partition, we could also have passed (False, (True, False), False) for the inner filter_spec, and (False, (False, True), False) for the outer filter_spec: this will explicitly pick out the LoRA and sparse objects by position, rather than by type.
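To see how those positional specs line up with `(fn, args, kwargs)`, here is a small hypothetical sketch (not Equinox's implementation) that splits a nested tuple using a boolean spec that is a prefix of its structure. One simplification: here a bool keeps or drops the whole subtree beneath it as a unit, whereas Equinox broadcasts it to every leaf of that subtree.

```python
class LoraArray: ...     # toy stand-ins for the real classes
class SparseVector: ...


def partition_by_spec(tree, spec):
    """Split `tree` by a boolean `spec` that is a prefix of its tuple structure."""
    if isinstance(spec, bool):
        return (tree, None) if spec else (None, tree)
    halves = [partition_by_spec(t, s) for t, s in zip(tree, spec)]
    return tuple(h[0] for h in halves), tuple(h[1] for h in halves)


matmul = lambda a, b: a @ b
lora, sparse = LoraArray(), SparseVector()
fn_args_kwargs = (matmul, (lora, sparse), {})

# Inner spec picks the first positional argument; outer spec picks the second.
inner_dynamic, inner_static = partition_by_spec(
    fn_args_kwargs, (False, (True, False), False)
)
outer_dynamic, _ = partition_by_spec(fn_args_kwargs, (False, (False, True), False))
print(inner_dynamic[1][0] is lora, outer_dynamic[1][1] is sparse)  # True True
```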

Note

The order of these two quaxifies is important. If we'd done it the other way around, then we would have hit the ArrayValue @ SomeKindOfSparseVector combination first. However, that involves indexing (x[:, y.index]), and we (a) haven't provided an override for that operation for LoraArray, and (b) have disallowed materialising the LoraArray. So if we'd switched the quaxifies, we would have gotten a trace-time error.

Solution 2: override the combination

Okay, on to the second (less preferred) way: we can explicitly define an override rule for the combination:

@quax.register(jax.lax.dot_general_p)
def _(
    x: LoraArray, y: SomeKindOfSparseVector, *, dimension_numbers, **params
): ...  # some implementation here

However, this is discouraged because (a) it mutates global state (the multiple-dispatch lookup table), which could affect other parts of your codebase or the libraries you use, and (b) it means that you have to work out the implementation for this combination yourself.
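A rough illustration of both points, using a toy dispatcher (hypothetical; not how Quax or plum work internally): the rule table is a single module-level object, so registering the combination rule changes what every caller in the process gets back. The exact `(LoraArray, SparseVector)` signature is at least as specific as both existing rules in every position, so it resolves the ambiguity, but only because someone wrote a body for it.

```python
class ArrayValue: ...    # toy stand-ins for the real classes
class LoraArray(ArrayValue): ...
class SparseVector(ArrayValue): ...

# A single module-level table shared by every caller: the "global state".
RULES = {
    (LoraArray, ArrayValue): "lora rule",
    (ArrayValue, SparseVector): "sparse rule",
}


def resolve(*args):
    applicable = [
        sig for sig in RULES if all(isinstance(x, t) for x, t in zip(args, sig))
    ]
    # Keep rules that are at least as specific as every other applicable rule.
    best = [
        s for s in applicable
        if all(all(issubclass(p, q) for p, q in zip(s, o)) for o in applicable)
    ]
    if len(best) != 1:
        raise LookupError(f"ambiguous or missing: {applicable}")
    return RULES[best[0]]


# Solution 2: register the exact combination; this mutates RULES for everyone.
RULES[(LoraArray, SparseVector)] = "combination rule"
chosen = resolve(LoraArray(), SparseVector())
print(chosen)  # combination rule
```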