Introductory tutorial - custom rules - Unitful
In this example, we'll see how to create a custom array-ish Quax type.
We're going to implement a "unitful" type, which annotates each array with a unit like "length in meters" or "time in seconds". It will keep track of the units as they propagate through the computation, and disallow things like adding a length-array to a time-array. (Which isn't a thing you can do in physics!)
from typing import Union
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike  # https://github.com/patrick-kidger/jaxtyping
import quax
As a first step for this example (unrelated to Quax), let's define a toy unit system. (In this simple system we only have "meters" etc., but no notion of "kilometers"/"miles" etc.)
class Dimension:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name

kilograms = Dimension("kg")
meters = Dimension("m")
seconds = Dimension("s")

def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]:
    if isinstance(x, Dimension):
        return {x: 1}
    else:
        return x
Now let's define our custom Quax type. It'll wrap together an array and a unit.
class Unitful(quax.ArrayValue):
    array: ArrayLike
    units: dict[Dimension, int] = eqx.field(static=True, converter=_dim_to_unit)

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise Unitful array.")
Example usage for this is Unitful(array, meters) to indicate that the array has units of meters, or Unitful(array, {meters: 1, seconds: -1}) to indicate that the array has units of meters-per-second.
Now let's define a few rules for how unitful arrays interact with each other.
@quax.register(jax.lax.add_p)
def _(x: Unitful, y: Unitful):  # function name doesn't matter
    if x.units == y.units:
        return Unitful(x.array + y.array, x.units)
    else:
        raise ValueError(f"Cannot add two arrays with units {x.units} and {y.units}.")

@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: Unitful):
    units = x.units.copy()
    for k, v in y.units.items():
        if k in units:
            units[k] += v
        else:
            units[k] = v
    return Unitful(x.array * y.array, units)

@quax.register(jax.lax.mul_p)
def _(x: ArrayLike, y: Unitful):
    return Unitful(x * y.array, y.units)

@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: ArrayLike):
    return Unitful(x.array * y, x.units)

@quax.register(jax.lax.integer_pow_p)
def _(x: Unitful, *, y: int):
    # Raise both the values and the units to the power y.
    units = {k: v * y for k, v in x.units.items()}
    return Unitful(x.array**y, units)
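The unit bookkeeping in the mul_p and integer_pow_p rules is plain dictionary arithmetic, so it can be checked without JAX or Quax. The sketch below pulls it out into two standalone functions, mul_units and pow_units. These are hypothetical names and not part of Quax. One assumption goes beyond the tutorial's rules: zero exponents are dropped, so that meters times meters-to-the-minus-one compares equal to the dimensionless {} rather than to {m: 0}. Without that step, the add rule's x.units == y.units check would treat {m: 0} and {} as different units.

```python
class Dimension:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


meters = Dimension("m")
seconds = Dimension("s")


def mul_units(xu, yu):
    """Units of x * y: exponents of shared dimensions add, as in the mul_p rule."""
    units = dict(xu)
    for dim, power in yu.items():
        units[dim] = units.get(dim, 0) + power
    # Assumed extra step (not in the tutorial's rule): drop zero exponents.
    return {dim: p for dim, p in units.items() if p != 0}


def pow_units(xu, y):
    """Units of x ** y: every exponent is scaled by y, as in the integer_pow_p rule."""
    # Same assumed clean-up, so y == 0 gives the dimensionless {}.
    return {dim: p * y for dim, p in xu.items() if p * y != 0}


velocity = {meters: 1, seconds: -1}
print(pow_units(velocity, 2))             # {m: 2, s: -2}
print(mul_units(velocity, {seconds: 1}))  # {m: 1}
```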
And now let's go ahead and use these in practice!
As our example, we'll consider computing the energy of a ball moving in Earth's gravity.
def kinetic_energy(mass, velocity):
    """Kinetic energy of a ball with `mass` moving with `velocity`."""
    return 0.5 * mass * velocity**2

def gravitational_potential_energy(mass, height, g):
    """Gravitational potential energy of a ball with `mass` at a distance `height` above
    the Earth's surface.
    """
    return g * mass * height

def compute_energy(mass, velocity, height, g):
    return kinetic_energy(mass, velocity) + gravitational_potential_energy(
        mass, height, g
    )
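Before running this, it is worth checking by hand that the two terms added in compute_energy carry the same units. If they did not, the add_p rule would raise. In the block below, [q] denotes the units of q, m is the mass, and \mathrm{m} is meters. The factor 0.5 is a plain scalar handled by the ArrayLike-times-Unitful rule, so it contributes no units.

```latex
\begin{aligned}
[E_{\mathrm{kin}}] &= [m]\,[v]^2 = \mathrm{kg}\,(\mathrm{m}\,\mathrm{s}^{-1})^2 = \mathrm{kg}\,\mathrm{m}^{2}\,\mathrm{s}^{-2},\\
[E_{\mathrm{pot}}] &= [g]\,[m]\,[h] = (\mathrm{m}\,\mathrm{s}^{-2})\,\mathrm{kg}\,\mathrm{m} = \mathrm{kg}\,\mathrm{m}^{2}\,\mathrm{s}^{-2}.
\end{aligned}
```

Both terms come out as kg m^2 s^-2 (joules). The two units dicts list their keys in different orders. That does not matter, because the add rule compares them with ==, and dict equality ignores key order.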
m = Unitful(jnp.array(3.0), kilograms)
v = Unitful(jnp.array(2.2), {meters: 1, seconds: -1})
h = Unitful(jnp.array(1.0), meters)
# acceleration due to Earth's gravity.
g = Unitful(jnp.array(9.81), {meters: 1, seconds: -2})
E = quax.quaxify(compute_energy)(m, v, h, g)
print(f"The amount of energy is {E.array.item()} with units {E.units}.")
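As a sanity check on the number, here is the same computation with plain Python floats and no units. It only matches if the integer_pow_p rule squares the underlying array as well as its units. The JAX arrays above are float32 by default, so E.array.item() can differ from this value in the last few digits.

```python
# Plain-float reference value for compute_energy(m, v, h, g), ignoring units.
mass, velocity, height, g = 3.0, 2.2, 1.0, 9.81

kinetic = 0.5 * mass * velocity**2  # 0.5 * 3.0 * 2.2**2, about 7.26
potential = g * mass * height       # 9.81 * 3.0 * 1.0, about 29.43
energy = kinetic + potential

print(f"{energy:.2f}")  # 36.69
```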
Wonderful! That went perfectly.
The key take-aways from this example are:
- The basic usage of defining a custom type with its aval and materialise methods.
- How to define a rule that binds your custom type against itself, e.g.
      @quax.register(jax.lax.mul_p)
      def _(x: Unitful, y: Unitful): ...
- How to define a rule that binds your custom type against a normal JAX arraylike type, e.g.
      @quax.register(jax.lax.mul_p)
      def _(x: ArrayLike, y: Unitful): ...
  (An ArrayLike is any of the things JAX is normally willing to have interact with arrays: bool/int/float/complex/NumPy scalars/NumPy arrays/JAX arrays. You can think of the purpose of Quax as being a way to extend what it means for an object to be arraylike.)
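The pure-Python sketch below is one way to picture how a rule gets chosen for a given pair of argument types. It is a toy only, not how Quax is implemented: Quax registers rules against JAX primitives such as jax.lax.mul_p and does much more. The names RULES, register, bind and _kind are hypothetical. Number stands in for ArrayLike, and dimensions are plain strings for brevity. The sketch shows the two registration patterns above. It also shows the fallback described in the bad examples below: when no rule matches, the Unitful input gets materialised, which this type refuses to allow.

```python
import operator
from numbers import Number


class Unitful:
    """Toy stand-in for the tutorial's Unitful: a plain number tagged with units."""

    def __init__(self, value, units):
        self.value = value
        self.units = units

    def materialise(self):
        # Same policy as the tutorial: never silently drop the units.
        raise ValueError("Refusing to materialise Unitful array.")


# Hypothetical rule table keyed by (primitive name, type of x, type of y).
RULES = {}
OPS = {"add": operator.add, "mul": operator.mul}


def register(primitive, xtype, ytype):
    def decorator(fn):
        RULES[(primitive, xtype, ytype)] = fn
        return fn

    return decorator


def _kind(obj):
    # Anything that is not Unitful is treated as a plain "arraylike" number.
    return Unitful if isinstance(obj, Unitful) else Number


def bind(primitive, x, y):
    rule = RULES.get((primitive, _kind(x), _kind(y)))
    if rule is not None:
        return rule(x, y)
    # No matching rule: materialise any Unitful input, then do the plain operation.
    x = x.materialise() if isinstance(x, Unitful) else x
    y = y.materialise() if isinstance(y, Unitful) else y
    return OPS[primitive](x, y)


@register("mul", Unitful, Unitful)  # custom type against itself
def _mul_unitful_unitful(x, y):
    units = dict(x.units)
    for dim, power in y.units.items():
        units[dim] = units.get(dim, 0) + power
    return Unitful(x.value * y.value, units)


@register("mul", Number, Unitful)  # plain arraylike against custom type
def _mul_number_unitful(x, y):
    return Unitful(x * y.value, y.units)


speed = Unitful(2.0, {"m": 1, "s": -1})
print(bind("mul", 0.5, speed).units)    # {'m': 1, 's': -1}
print(bind("mul", speed, speed).units)  # {'m': 2, 's': -2}
try:
    bind("add", speed, 1.0)  # no ("add", Unitful, Number) rule is registered
except ValueError as e:
    print(repr(e))
```

The final call fails for the same reason as Example 2 below. There is no rule for adding a Unitful to a plain number, so the fallback tries to materialise the Unitful value.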
Mistakes we didn't make
Now let's look at all the ways we could have gotten an error by doing things wrong.
# Bad example 1: a unit mismatch
def bad_physics(x, y):
    return x + y

try:
    # This will throw an error because the addition rule only allows adding arrays with
    # the same units as each other.
    quax.quaxify(bad_physics)(m, v)
except ValueError as e:
    print(f"Example 1 raises error {repr(e)}")
# Bad example 2: trying to add a normal JAX array onto a Unitful quantity:
try:
    # This will throw an error because there's no rule for adding a Unitful array to
    # a normal JAX array -- that is, we haven't defined a rule like
    # ```
    # @quax.register(jax.lax.add_p)
    # def _(x: Unitful, y: ArrayLike):
    #     ...
    # ```
    # so Quax tries to materialise the Unitful array into a normal JAX array. However,
    # we've explicitly made `materialise` raise an error.
    quax.quaxify(bad_physics)(m, jnp.array(0))
except ValueError as e:
    print(f"Example 2 raises error {repr(e)}")
# Bad example 3: trying to create a `Unitful` type *without* passing it across a
# `quax.quaxify` boundary.
def unquaxed_example(x):
    return Unitful(1, meters) * x

try:
    # This will throw an error because there is (deliberately) no `__mul__` method on
    # `Unitful`. The pattern is that we (a) create a Quax type, and then (b) pass it
    # across a `quax.quaxify` boundary. Whilst we *could* define a `__mul__` method, it
    # might dangerously start encouraging us to use Quax in a way it isn't designed
    # for.
    quax.quaxify(unquaxed_example)(10)
except TypeError as e:
    print(f"Example 3 raises error {repr(e)}")
# Bad example 4: trying to create a `Unitful` type *without* passing it across a
# `quax.quaxify` boundary (again!).
def another_unquaxed_example(x):
    return jax.lax.mul(Unitful(1, meters), x)

try:
    # This will throw an error because Quax will attempt to bind `Unitful` directly,
    # without it having passed across a `quaxify` boundary and being wrapped into a
    # Quax tracer.
    # As this is a common mistake, we have a special long-winded error message.
    quax.quaxify(another_unquaxed_example)(Unitful(10, meters))
except TypeError as e:
    print(f"\nExample 4 raises {type(e).__name__} with message:\n{e}")
The key take-away from this set of failures is that you must create a Quax type and then immediately pass it across a quax.quaxify boundary. You can't create one once you're inside! (After all, if you had two nested quax.quaxify calls, which transform should a Quax type created inside be associated with? It could belong to either.)