Intermediate tutorial - default rules

In the previous tutorial, we saw how to overload a particular primitive-type combination. But what if we have a new type and want to overload every primitive for it? For this we have default rules.

Here's an example of a type that detects whether we're in the forward or backward pass of backpropagation.
(This is useful with quantisation, for example, where we often want to quantise differently in each pass.)
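For comparison, plain JAX can also quantise differently in each pass via `jax.custom_vjp`, which attaches a hand-written backward rule to a single function. The minimal sketch below uses that swapped-in technique, not the Quax approach this tutorial builds; the grid steps (0.1 forward, 0.5 backward) and the name `quantise` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def _round_to(x, step):
    # Round to the nearest multiple of `step`.
    return jnp.round(x / step) * step


@jax.custom_vjp
def quantise(x):
    # Forward pass: quantise to a fine grid (step 0.1, an arbitrary choice).
    return _round_to(x, 0.1)


def _quantise_fwd(x):
    # Same forward computation; no residuals are needed for the backward rule.
    return _round_to(x, 0.1), None


def _quantise_bwd(_residuals, g):
    # Backward pass: treat the forward rounding as the identity (a straight-through
    # estimator) and quantise the incoming cotangent to a coarser grid (step 0.5).
    return (_round_to(g, 0.5),)


quantise.defvjp(_quantise_fwd, _quantise_bwd)

x = jnp.float32(1.23)
print(quantise(x))  # approximately 1.2
print(jax.grad(lambda x: 0.3 * quantise(x))(x))  # 0.5: the cotangent 0.3 rounded to the 0.5 grid
```

The difference in design: `custom_vjp` only changes the one function it wraps, whereas a tagged type changes every operation that touches the tagged value on the backward pass.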

In this example, we'll see how to create a custom array-like Quax type. In particular, we'll discuss a few important patterns for registering rules for a new type.

import functools as ft
from collections.abc import Sequence
from typing import Union

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from jaxtyping import ArrayLike  # https://github.com/patrick-kidger/jaxtyping

import quax

We begin by writing a "tag" type that just wraps an array. Whenever it is used in a JAX operation, the result of that operation will be tagged as well, even if the other inputs were ordinary arrays. This means we can see everything that happens downstream of a particular operation.

The interesting bit will be implementing a default rule for it.

class BackwardTag(quax.ArrayValue):
    array: ArrayLike

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    @staticmethod
    def default(
        primitive: jax.core.Primitive,
        values: Sequence[Union[ArrayLike, quax.Value]],
        params: dict,
    ):
        raw_values: list[ArrayLike] = []
        for value in values:
            if eqx.is_array_like(value):
                raw_values.append(value)
            elif isinstance(value, BackwardTag):
                raw_values.append(value.array)
            elif isinstance(value, quax.Value):
                raise ValueError(
                    "`BackwardTag` cannot be used in conjunction with other Quax types."
                )
            else:
                assert False  # should never happen
        out = primitive.bind(*raw_values, **params)
        if primitive.multiple_results:
            return [BackwardTag(x) for x in out]
        else:
            return BackwardTag(out)

    def materialise(self):
        # See the documentation for `quax.ArrayValue.{default,materialise}`.
        # This shouldn't ever be called for us.
        raise ValueError("Refusing to materialise and remove `BackwardTag`")
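To see the unwrap, bind, re-wrap pattern of `default` in isolation, here is a minimal sketch that runs without Quax. `Tag` is a hypothetical stand-in for `BackwardTag` (a plain frozen dataclass rather than a `quax.ArrayValue`), and `default_rule` mirrors the body of `BackwardTag.default` minus the error handling. Here we call it by hand, whereas Quax calls `default` for us whenever no more specific rule applies.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Tag:
    # Hypothetical stand-in for `BackwardTag`: it just wraps an array.
    array: jax.Array


def default_rule(primitive, values, params):
    # Unwrap any tagged inputs, leaving plain arrays untouched.
    raw_values = [v.array if isinstance(v, Tag) else v for v in values]
    # Run the primitive itself on the raw arrays.
    out = primitive.bind(*raw_values, **params)
    # Re-wrap every output, so the tag propagates downstream.
    if primitive.multiple_results:
        return [Tag(o) for o in out]
    return Tag(out)


x = Tag(jnp.arange(3.0))
y = jnp.ones(3)
z = default_rule(jax.lax.add_p, [x, y], {})  # a `Tag` wrapping [1., 2., 3.]
```

Note that `params` is passed through untouched: the rule never needs to know what the primitive does, which is what lets a single method cover every primitive.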

Okay, in some sense that's actually the end of the "part 2 tutorial" -- we've written our default rule!

But let's finish it off by using it in an interesting way: let's write something that looks like jax.value_and_grad, except that it passes a tagged value into the start of the backward pass.

def tagged_value_and_grad(fn):
    @ft.wraps(fn)
    def fn_wrapped(arg, *args, **kwargs):
        fn_all_args_except_first = lambda x: fn(x, *args, **kwargs)
        out, fn_vjp = jax.vjp(fn_all_args_except_first, arg)
        if not eqx.is_array_like(out) or jnp.shape(out) != ():
            raise ValueError(
                "Wrapped function must return a scalar, just like `jax.grad`."
            )
        # The interesting bit! We quaxify the backward pass.
        (grad,) = quax.quaxify(fn_vjp)(BackwardTag(1.0))
        unwrap_tag = lambda x: x.array if isinstance(x, BackwardTag) else x
        grad = jtu.tree_map(
            unwrap_tag, grad, is_leaf=lambda x: isinstance(x, BackwardTag)
        )
        return out, grad

    return fn_wrapped
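The core of tagged_value_and_grad is that, for a scalar-output function, pulling a cotangent of 1 back through `jax.vjp` gives the same result as `jax.grad`. A minimal pure-JAX check of that fact, using a made-up function `f` purely for illustration (no Quax or tagging involved):

```python
import jax
import jax.numpy as jnp


def f(x):
    # Any scalar-valued function works; this one is purely illustrative.
    return jnp.sum(jnp.sin(x) * x)


x = jnp.array([3.0, 4.0])

# Forward pass: the primal output plus a pullback (backward-pass) function.
out, f_vjp = jax.vjp(f, x)
# Backward pass: seed with a cotangent of 1, matching the output's shape and dtype.
(grad,) = f_vjp(jnp.ones_like(out))

# Compare against JAX's built-in `value_and_grad`.
ref_out, ref_grad = jax.value_and_grad(f)(x)
```

In the tutorial, `BackwardTag(1.0)` plays the role of `jnp.ones_like(out)`, which is why the backward-pass operations that depend on the seed receive tagged values.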

Now, as implemented... this hasn't actually done anything. Our backward pass uses tagged types, but then we unwrap them from the gradients at the end. Why bother?

Time for the useful bit: by registering a custom Quax rule, we can add custom behaviour only for operations that occur on the backward pass.

For this simple example, we're just going to have a print statement for all matmuls we encounter on the backward pass.

@quax.register(jax.lax.dot_general_p)
def _(lhs: BackwardTag, rhs: ArrayLike, **params):
    print("Performing a matmul with the tagged value on the LHS!")
    array = jax.lax.dot_general_p.bind(lhs.array, rhs, **params)
    return BackwardTag(array)


@quax.register(jax.lax.dot_general_p)
def _(lhs: ArrayLike, rhs: BackwardTag, **params):
    print("Performing a matmul with the tagged value on the RHS!")
    array = jax.lax.dot_general_p.bind(lhs, rhs.array, **params)
    return BackwardTag(array)


@quax.register(jax.lax.dot_general_p)
def _(lhs: BackwardTag, rhs: BackwardTag, **params):
    print("Performing a matmul with the tagged value on both sides!")
    array = jax.lax.dot_general_p.bind(lhs.array, rhs.array, **params)
    return BackwardTag(array)
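Why three registrations? Rules are selected by the types of all of their arguments, so each combination of tagged and untagged inputs needs its own rule, and any combination without a specific rule falls back to `default`. The following is a toy, pure-Python model of that lookup, written only to illustrate the idea: it is not how Quax implements dispatch, all names in it are hypothetical, and unlike real multiple dispatch it matches exact types only (no subclasses).

```python
# Toy model (hypothetical; not Quax's internals) of rule lookup with a default fallback.


class Plain:
    """Stands in for an ordinary array."""


class Tagged:
    """Stands in for a `BackwardTag`."""


def default_rule(name, *args):
    # Fallback used when no rule matches this exact combination of argument types.
    return f"default rule for {name}"


# Registry keyed by (primitive name, type of each argument).
RULES = {
    ("dot_general", Tagged, Plain): lambda lhs, rhs: "tagged on the LHS",
    ("dot_general", Plain, Tagged): lambda lhs, rhs: "tagged on the RHS",
    ("dot_general", Tagged, Tagged): lambda lhs, rhs: "tagged on both sides",
}


def bind_with_rules(name, *args):
    key = (name, *(type(arg) for arg in args))
    rule = RULES.get(key)
    if rule is None:
        return default_rule(name, *args)
    return rule(*args)


print(bind_with_rules("dot_general", Plain(), Tagged()))  # tagged on the RHS
print(bind_with_rules("add", Tagged(), Plain()))  # default rule for add
```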

And now here's a quick demonstration!

mlp = eqx.nn.MLP(in_size=2, out_size="scalar", width_size=32, depth=3, key=jr.key(0))
print(tagged_value_and_grad(mlp)(jnp.array([3.0, 4.0])))
Performing a matmul with the tagged value on the LHS!
Performing a matmul with the tagged value on the LHS!
Performing a matmul with the tagged value on the LHS!
Performing a matmul with the tagged value on the LHS!
(Array(0.08791415, dtype=float32), Array([ 0.0703299 , -0.04381455], dtype=float32))
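As a cross-check that doesn't involve Quax, we can also inspect the backward pass statically: trace the pullback returned by `jax.vjp` into a jaxpr and count its `dot_general` equations. This is a minimal sketch with a hypothetical three-layer network `f` built directly from `jax.lax.dot`, and the helper `count_primitive` is our own; unlike the Quax rule, it inspects the traced program rather than running code as each matmul is performed.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def count_primitive(jaxpr, primitive):
    # Count equations that use `primitive`, recursing into sub-jaxprs
    # (such as those created by a nested `jax.jit`).
    n = 0
    for eqn in jaxpr.eqns:
        if eqn.primitive is primitive:
            n += 1
        for param in eqn.params.values():
            if hasattr(param, "eqns"):  # a Jaxpr or ClosedJaxpr
                n += count_primitive(param, primitive)
    return n


k1, k2, k3 = jr.split(jr.key(0), 3)
W1 = jr.normal(k1, (8, 2))
W2 = jr.normal(k2, (8, 8))
w3 = jr.normal(k3, (8,))


def f(x):
    # Hypothetical three-layer network in which each matmul is one `dot_general`.
    h = jax.lax.tanh(jax.lax.dot(W1, x))
    h = jax.lax.tanh(jax.lax.dot(W2, h))
    return jax.lax.dot(w3, h)


x = jnp.array([3.0, 4.0])
out, f_vjp = jax.vjp(f, x)
# Trace only the backward pass (the pullback) into a jaxpr, then count its matmuls.
backward_jaxpr = jax.make_jaxpr(lambda ct: f_vjp(ct))(jnp.ones_like(out))
print(count_primitive(backward_jaxpr, jax.lax.dot_general_p))
```

Because we only differentiate with respect to the input, each layer should contribute one matmul (cotangent times weight) to the backward pass.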

The key take-aways from this example are:

  • How to use the default method.
  • How to use Quax as a tool hidden from the end user (in this case, to adjust the behaviour of a backward pass).