# Introductory tutorial - custom rules - Unitful

In this example, we'll see how to create a custom array-ish Quax type.

We're going to implement a "unitful" type, which annotates each array with a unit like "length in meters" or "time in seconds". It will keep track of the units as they propagate through the computation, and disallow things like adding a length-array to a time-array. (Which isn't a thing you can do in physics!)

```python
from typing import Union

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike  # https://github.com/patrick-kidger/jaxtyping

import quax  # https://github.com/patrick-kidger/quax
```


As a first step for this example (unrelated to Quax), let's define a toy unit system. (In this simple system we only have "meters" etc., but no notion of "kilometers"/"miles" etc.)

```python
class Dimension:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


kilograms = Dimension("kg")
meters = Dimension("m")
seconds = Dimension("s")


def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]:
    if isinstance(x, Dimension):
        return {x: 1}
    else:
        return x
```


Now let's define our custom Quax type. It'll wrap together an array and a unit.

```python
class Unitful(quax.ArrayValue):
    array: ArrayLike
    units: dict[Dimension, int] = eqx.field(static=True, converter=_dim_to_unit)

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise Unitful array.")
```


Example usage for this is Unitful(array, meters) to indicate that the array has units of meters, or Unitful(array, {meters: 1, seconds: -1}) to indicate the array has units of meters-per-second.

Now let's define a few rules for how unitful arrays interact with each other.

```python
@quax.register(jax.lax.add_p)
def _(x: Unitful, y: Unitful):  # function name doesn't matter
    if x.units == y.units:
        return Unitful(x.array + y.array, x.units)
    else:
        raise ValueError(f"Cannot add two arrays with units {x.units} and {y.units}.")


@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: Unitful):
    units = x.units.copy()
    for k, v in y.units.items():
        if k in units:
            units[k] += v
        else:
            units[k] = v
    return Unitful(x.array * y.array, units)


@quax.register(jax.lax.mul_p)
def _(x: ArrayLike, y: Unitful):
    return Unitful(x * y.array, y.units)


@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: ArrayLike):
    return Unitful(x.array * y, x.units)


@quax.register(jax.lax.integer_pow_p)
def _(x: Unitful, *, y: int):
    # Raise the values to the power y, and scale every unit exponent by y.
    array = x.array**y
    units = {k: v * y for k, v in x.units.items()}
    return Unitful(array, units)
```

And now let's go ahead and use these in practice!

As our example, we'll consider computing the energy of a ball moving in Earth's gravity.

```python
def kinetic_energy(mass, velocity):
    """Kinetic energy of a ball with mass moving with velocity."""
    return 0.5 * mass * velocity**2


def gravitational_potential_energy(mass, height, g):
    """Gravitational potential energy of a ball with mass at a distance height above
    the Earth's surface.
    """
    return g * mass * height


def compute_energy(mass, velocity, height, g):
    return kinetic_energy(mass, velocity) + gravitational_potential_energy(
        mass, height, g
    )


m = Unitful(jnp.array(3.0), kilograms)
v = Unitful(jnp.array(2.2), {meters: 1, seconds: -1})
h = Unitful(jnp.array(1.0), meters)
# acceleration due to Earth's gravity.
g = Unitful(jnp.array(9.81), {meters: 1, seconds: -2})
E = quax.quaxify(compute_energy)(m, v, h, g)
print(f"The amount of energy is {E.array.item()} with units {E.units}.")
```

```
The amount of energy is 36.69000244140625 with units {kg: 1, m: 2, s: -2}.
```



Wonderful! That went perfectly.

The key take-aways from this example are:

• The basics of defining a custom type, with its aval and materialise methods.
• How to define a rule that binds your custom type against itself, e.g.
@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: Unitful): ...

• How to define a rule that binds your custom type against a normal JAX arraylike type, e.g.
@quax.register(jax.lax.mul_p)
def _(x: ArrayLike, y: Unitful): ...

(An ArrayLike is all the things JAX is normally willing to have interact with arrays: bool/int/float/complex/NumPy scalars/NumPy arrays/JAX arrays. You can think of the purpose of Quax as being a way to extend what it means for an object to be arraylike.)

## Mistakes we didn't make

Now let's look at all the ways we could have gotten an error by doing things wrong.

```python
# Bad example 1: a unit mismatch


def add(x, y):
    return x + y


try:
    # This will throw an error because the addition rule only allows adding arrays with
    # the same units as each other.
    quax.quaxify(add)(m, v)
except ValueError as e:
    print(f"Example 1 raises error {repr(e)}")

# Bad example 2: trying to add a normal JAX array onto a Unitful quantity.

try:
    # This will throw an error because there's no rule for adding a Unitful array to
    # a normal JAX array -- that is, we haven't defined a rule like
    #
    # @quax.register(jax.lax.add_p)
    # def _(x: Unitful, y: ArrayLike):
    #     ...
    #
    # so Quax tries to materialise the Unitful array into a normal JAX array. However,
    # we've explicitly made materialise raise an error.
    quax.quaxify(add)(m, jnp.array(0.0))
except ValueError as e:
    print(f"Example 2 raises error {repr(e)}")

# Bad example 3: trying to create a Unitful type *without* passing it across a
# quax.quaxify boundary.


def unquaxed_example(x):
    return Unitful(1, meters) * x


try:
    # This will throw an error because there is (deliberately) no __mul__ method on
    # Unitful. The pattern is that we (a) create a Quax type, and then (b) pass it
    # across a quax.quaxify boundary. Whilst we *could* define a __mul__ method, it
    # might dangerously start encouraging us to use Quax in a way it isn't designed
    # for.
    quax.quaxify(unquaxed_example)(10)
except TypeError as e:
    print(f"Example 3 raises error {repr(e)}")

# Bad example 4: trying to create a Unitful type *without* passing it across a
# quax.quaxify boundary (again!).


def another_unquaxed_example(x):
    return jax.lax.mul(Unitful(1, meters), x)


try:
    # This will throw an error because Quax will attempt to bind Unitful directly,
    # without it having passed across a quaxify boundary and being wrapped into a
    # Quax tracer.
    # As this is a common mistake, we have a special long-winded error message.
    quax.quaxify(another_unquaxed_example)(Unitful(10, meters))
except TypeError as e:
    print(f"\nExample 4 raises {type(e).__name__} with message:\n{e}")
```

```
Example 1 raises error ValueError('Cannot add two arrays with units {kg: 1} and {m: 1, s: -1}.')
Example 2 raises error ValueError('Refusing to materialise Unitful array.')
Example 3 raises error TypeError("unsupported operand type(s) for *: 'Unitful' and 'int'")

Example 4 raises TypeError with message:
Encountered Quax value of type <class '__main__.Unitful'>. These must be transformed by passing them across a quax.quaxify boundary before being used.
For example, the following is incorrect, as SomeValue() is not explicitly passed across the API boundary:
def f(x):
    return x + SomeValue()

quax.quaxify(f)(AnotherValue())

This should instead be written as the following:
explicitly passed across the API boundary:
def f(x, y):
    return x + y

quax.quaxify(f)(AnotherValue(), SomeValue())

To better understand this, remember that the purpose of Quax is take a JAX program (given as a function) that acts on arrays, and to instead run it with array-ish types. But in the first example above, the original program already has an array-ish type, even before the quaxify is introduced.
```



The key take-away from this set of failures is that you must create a Quax type and then immediately pass it across a quax.quaxify boundary. You can't create one once you're inside! (After all, if you had two nested quax.quaxify calls, which transform should a Quax type created inside be associated with? Either would be a plausible choice.)