
Introductory tutorial - custom rules - Unitful

In this example, we'll see how to create a custom array-ish Quax type.

We're going to implement a "unitful" type, which annotates each array with a unit like "length in meters" or "time in seconds". It will keep track of the units as they propagate through the computation, and disallow things like adding a length-array to a time-array. (Which isn't a thing you can do in physics!)

from typing import Union

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike

import quax

As a first step for this example (unrelated to Quax), let's define a toy unit system. (In this simple system we only have "meters" etc., but no notion of "kilometers"/"miles" etc.)

class Dimension:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name

kilograms = Dimension("kg")
meters = Dimension("m")
seconds = Dimension("s")

def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]:
    if isinstance(x, Dimension):
        return {x: 1}
    else:
        return x

Now let's define our custom Quax type. It'll wrap together an array and a unit.

class Unitful(quax.ArrayValue):
    array: ArrayLike
    units: dict[Dimension, int] = eqx.field(static=True, converter=_dim_to_unit)

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise Unitful array.")

Example usage for this is Unitful(array, meters) to indicate that the array has units of meters, or Unitful(array, {meters: 1, seconds: -1}) to indicate the array has units of meters-per-second.

Now let's define a few rules for how unitful arrays interact with each other.

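@quax.register(jax.lax.add_p)
def _(x: Unitful, y: Unitful):  # function name doesn't matter
    # Addition is only allowed when both operands have exactly the same units.
    if x.units == y.units:
        return Unitful(x.array + y.array, x.units)
    else:
        raise ValueError(f"Cannot add two arrays with units {x.units} and {y.units}.")

@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: Unitful):
    # Multiplication adds the exponents of each dimension, e.g. m * (m/s) -> m^2/s.
    units = x.units.copy()
    for k, v in y.units.items():
        if k in units:
            units[k] += v
        else:
            units[k] = v
    return Unitful(x.array * y.array, units)

@quax.register(jax.lax.mul_p)
def _(x: ArrayLike, y: Unitful):
    # A plain (unitless) array times a Unitful array keeps the Unitful's units.
    return Unitful(x * y.array, y.units)

@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: ArrayLike):
    return Unitful(x.array * y, x.units)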

@quax.register(jax.lax.integer_pow_p)
def _(x: Unitful, *, y: int):
    # x**y: scale every unit exponent by y, and raise the values themselves to the power y.
    units = {k: v * y for k, v in x.units.items()}
    return Unitful(jax.lax.integer_pow(x.array, y), units)

And now let's go ahead and use these in practice!

As our example, we'll consider computing the energy of a ball moving in Earth's gravity.

def kinetic_energy(mass, velocity):
    """Kinetic energy of a ball with `mass` moving with `velocity`."""
    return 0.5 * mass * velocity**2

def gravitational_potential_energy(mass, height, g):
    """Gravitational potential energy of a ball with `mass` at a distance `height` above
    the Earth's surface.
    """
    return g * mass * height

def compute_energy(mass, velocity, height, g):
    return kinetic_energy(mass, velocity) + gravitational_potential_energy(
        mass, height, g
    )

m = Unitful(jnp.array(3.0), kilograms)
v = Unitful(jnp.array(2.2), {meters: 1, seconds: -1})
h = Unitful(jnp.array(1.0), meters)
# acceleration due to Earth's gravity.
g = Unitful(jnp.array(9.81), {meters: 1, seconds: -2})
E = quax.quaxify(compute_energy)(m, v, h, g)
print(f"The amount of energy is {E.array.item()} with units {E.units}.")
The amount of energy is 36.69000244140625 with units {kg: 1, m: 2, s: -2}.

Wonderful! That went perfectly.
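The final add rule only succeeds because both energy terms end up with identical units, kg m^2 s^-2 (joules). The following pure-Python sketch replays just the unit bookkeeping of compute_energy, using the same exponent arithmetic as the rules above. It assumes plain strings in place of Dimension objects and ignores the array values entirely; the helper names mul_units, pow_units and add_units are hypothetical.

```python
def mul_units(a, b):
    # Same logic as the Unitful * Unitful rule: exponents of each dimension add.
    out = dict(a)
    for dim, power in b.items():
        out[dim] = out.get(dim, 0) + power
    return out


def pow_units(a, n):
    # Same logic as the integer_pow rule: every exponent is multiplied by n.
    return {dim: power * n for dim, power in a.items()}


def add_units(a, b):
    # Same logic as the add rule: the units must match exactly.
    if a != b:
        raise ValueError(f"Cannot add units {a} and {b}.")
    return a


mass = {"kg": 1}
velocity = {"m": 1, "s": -1}
height = {"m": 1}
g = {"m": 1, "s": -2}

# Multiplying by the plain number 0.5 keeps mass's units, so only mass * velocity**2 matters.
kinetic = mul_units(mass, pow_units(velocity, 2))
potential = mul_units(mul_units(g, mass), height)
energy = add_units(kinetic, potential)
print(energy)  # {'kg': 1, 'm': 2, 's': -2}
```

Calling add_units(mass, velocity) instead raises a ValueError, which is the same check that rejects Example 1 below.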

The key take-aways from this example are:

  • The basics of defining a custom type, with its aval and materialise methods.
  • How to define a rule that binds your custom type against itself, e.g.
    @quax.register(jax.lax.add_p)
    def _(x: Unitful, y: Unitful): ...
  • How to define a rule that binds your custom type against a normal JAX arraylike type, e.g.
    @quax.register(jax.lax.mul_p)
    def _(x: ArrayLike, y: Unitful): ...
    (An ArrayLike is all the things JAX is normally willing to have interact with arrays: bool/int/float/complex/NumPy scalars/NumPy arrays/JAX arrays. You can think of the purpose of Quax as being a way to extend what it means for an object to be arraylike.)
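Roughly speaking, a rule is picked according to the types of the arguments in order, which is why the tutorial registers separate multiplication rules for (ArrayLike, Unitful) and (Unitful, ArrayLike). The sketch below is a toy, pure-Python illustration of that kind of type-keyed lookup, not how Quax is implemented. Its Unitful class is a hypothetical simplified stand-in, numbers.Number plays the role of ArrayLike, and toy_mul and MUL_RULES are made-up names.

```python
from numbers import Number


class Unitful:
    # Hypothetical stand-in for the tutorial's Unitful: a value plus a units dict.
    def __init__(self, value, units):
        self.value = value
        self.units = units


def mul_plain_unitful(x, y):
    # Plain number on the left: keep the right operand's units.
    return Unitful(x * y.value, y.units)


def mul_unitful_plain(x, y):
    # Plain number on the right: keep the left operand's units.
    return Unitful(x.value * y, x.units)


def mul_unitful_unitful(x, y):
    # Two Unitful operands: exponents of shared dimensions add.
    units = dict(x.units)
    for dim, power in y.units.items():
        units[dim] = units.get(dim, 0) + power
    return Unitful(x.value * y.value, units)


# Toy rule table keyed on (type of x, type of y); the lookup is order-sensitive.
MUL_RULES = [
    ((Unitful, Unitful), mul_unitful_unitful),
    ((Number, Unitful), mul_plain_unitful),
    ((Unitful, Number), mul_unitful_plain),
]


def toy_mul(x, y):
    for (type_x, type_y), rule in MUL_RULES:
        if isinstance(x, type_x) and isinstance(y, type_y):
            return rule(x, y)
    raise TypeError(f"No rule for {type(x).__name__} * {type(y).__name__}")


half_mass = toy_mul(0.5, Unitful(3.0, {"kg": 1}))
print(half_mass.value, half_mass.units)  # 1.5 {'kg': 1}
```

In this toy, deleting the (Unitful, Number) entry would make toy_mul(Unitful(2.0, {"m": 1}), 3) fail even though toy_mul(3, Unitful(2.0, {"m": 1})) still works, since each argument order is looked up separately.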

Mistakes we didn't make

Now let's look at all the ways we could have gotten an error by doing things wrong.

# Bad example 1: a unit mismatch

def bad_physics(x, y):
    return x + y

try:
    # This will throw an error because the addition rule only allows adding arrays with
    # the same units as each other.
    quax.quaxify(bad_physics)(m, v)
except ValueError as e:
    print(f"Example 1 raises error {repr(e)}")

# Bad example 2: trying to add a normal JAX array onto a Unitful quantity:

try:
    # This will throw an error because there's no rule for adding a Unitful array to
    # a normal JAX array -- that is, we haven't defined a rule like
    # ```
    # @quax.register(jax.lax.add_p)
    # def _(x: Unitful, y: ArrayLike):
    #     ...
    # ```
    # so Quax tries to materialise the Unitful array into a normal JAX array. However,
    # we've explicitly made `materialise` raise an error.
    quax.quaxify(bad_physics)(m, jnp.array(0))
except ValueError as e:
    print(f"Example 2 raises error {repr(e)}")

# Bad example 3: trying to create a `Unitful` type *without* passing it across a
# `quax.quaxify` boundary.

def unquaxed_example(x):
    return Unitful(1, meters) * x

try:
    # This will throw an error because there is (deliberately) no `__mul__` method on
    # `Unitful`. The pattern is that we (a) create a Quax type, and then (b) pass it
    # across a `quax.quaxify` boundary. Whilst we *could* define a `__mul__` method, it
    # might dangerously start encouraging us to use Quax in a way it isn't designed
    # for.
    quax.quaxify(unquaxed_example)(2)
except TypeError as e:
    print(f"Example 3 raises error {repr(e)}")

# Bad example 4: trying to create a `Unitful` type *without* passing it across a
# `quax.quaxify` boundary (again!).

def another_unquaxed_example(x):
    return jax.lax.mul(Unitful(1, meters), x)

try:
    # This will throw an error because Quax will attempt to bind `Unitful` directly,
    # without it having passed across a `quaxify` boundary and being wrapped into a
    # Quax tracer.
    # As this is a common mistake, we have a special long-winded error message.
    quax.quaxify(another_unquaxed_example)(Unitful(10, meters))
except TypeError as e:
    print(f"\nExample 4 raises {type(e).__name__} with message:\n{e}")
Example 1 raises error ValueError('Cannot add two arrays with units {kg: 1} and {m: 1, s: -1}.')
Example 2 raises error ValueError('Refusing to materialise Unitful array.')
Example 3 raises error TypeError("unsupported operand type(s) for *: 'Unitful' and 'int'")

Example 4 raises TypeError with message:
Encountered Quax value of type <class '__main__.Unitful'>. These must be transformed by passing them across a `quax.quaxify` boundary before being used.
For example, the following is incorrect, as `SomeValue()` is not explicitly passed across the API boundary:
def f(x):
    return x + SomeValue()

This should instead be written as the following: explicitly passed across the API boundary:
def f(x, y):
    return x + y

quax.quaxify(f)(AnotherValue(), SomeValue())
To better understand this, remember that the purpose of Quax is take a JAX program (given as a function) that acts on arrays, and to instead run it with array-ish types. But in the first example above, the original program already has an array-ish type, even before the `quaxify` is introduced.

The key take-away from this set of failures is that you must create a Quax type and then immediately pass it across a quax.quaxify boundary. You can't create one once you're inside! (After all, what if you had two nested quax.quaxify calls? A Quax type created inside could plausibly be associated with either transform, so there is no well-defined way to use it.)