
quax

An end user of a library built on Quax needs only one thing from this section: the quax.quaxify function.

quax.quaxify(fn, filter_spec=True)

'Quaxifies' a function, so that it understands custom array-ish objects like quax.examples.lora.LoraArray. When this function is called, multiple dispatch will be performed against the types it is called with.

Arguments:

  • fn: the function to wrap.
  • filter_spec: which arguments to quaxify. Advanced usage, see tip below.

Returns:

A copy of fn that understands all Quax types.

Only quaxifying some arguments

Calling quax.quaxify(fn, filter_spec)(*args, **kwargs) will under-the-hood run dynamic, static = eqx.partition((fn, args, kwargs), filter_spec), and then only quaxify those arguments in dynamic. This allows some quax.Values to be passed through into the function unchanged, typically so that they can hit a nested quax.quaxify. See the advanced tutorial.


A developer of a library built on Quax (e.g. if you wanted to write your own library analogous to quax.examples.lora) should additionally know about the following functionality.

Info

See also the tutorials for creating your own array-ish Quax types.

quax.register(primitive: Primitive)

Registers a multiple dispatch implementation for this JAX primitive.

Example

Used as a decorator, and requires type annotations to perform multiple dispatch:

import jax
import quax

@quax.register(jax.lax.add_p)
def _(x: SomeValue, y: SomeValue):  # SomeValue: your own quax.Value subclass
    return ...  # some implementation

All positional arguments will be (subclasses of) quax.Value -- these are the set of types that Quax will attempt to perform multiple dispatch with.

All keyword arguments will be the parameters for this primitive, as passed to primitive.bind(..., **params).

Arguments:

  • primitive: The jax.core.Primitive to provide a multiple dispatch implementation for.

Returns:

A decorator for registering a multiple dispatch rule with the specified primitive.

quax.Value

Represents an object which Quax can perform multiple dispatch with.

In practice you will almost always want to inherit from quax.ArrayValue instead, which represents specifically an array-ish object that can be used for multiple dispatch.

aval(self) -> AbstractValue abstractmethod

All concrete subclasses must implement this method, specifying the abstract value seen by JAX.

Arguments:

Nothing.

Returns:

Any subclass of jax.core.AbstractValue. Typically a jax.core.ShapedArray.

default(primitive, values: Sequence[Union[ArrayLike, Value]], params) -> Union[ArrayLike, Value, Sequence[Union[ArrayLike, Value]]] staticmethod

This is the default rule for when no rule has been quax.register'd for a primitive.

When performing multiple dispatch on primitive.bind(value1, value2, value3):

  1. If there is a dispatch rule matching the types of value1, value2, and value3, then that will be used.
  2. If precisely one of the types of value{1,2,3} overloads this method, then that default rule will be used.
  3. If none of the types of value{1,2,3} overloads this method, then all values are materialised (via quax.Value.materialise), and the usual JAX implementation is called.
  4. If multiple of the types of value{1,2,3} overload this method, then a trace-time error will be raised.
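The four cases above can be modelled with a small pure-Python toy. This is a hypothetical illustration of the decision order only, not Quax's implementation: Value, bind, Plain, Sparse and Lora here are stand-ins, and rule lookup uses exact types, whereas Quax's real multiple dispatch also matches subclasses.

```python
class Value:
    """Toy stand-in for quax.Value (hypothetical, not the real class)."""

    @staticmethod
    def default(primitive, values, params):
        # Stands in for case 3: materialise everything, call the usual JAX rule.
        return ("materialise", primitive)


def overrides_default(cls):
    # Looking up a staticmethod on a class returns the plain function, so an
    # identity check tells whether `cls` (or a parent) replaced Value.default.
    return issubclass(cls, Value) and cls.default is not Value.default


def bind(primitive, values, rules, **params):
    types = tuple(type(v) for v in values)
    # Case 1: a rule registered for exactly these argument types wins.
    rule = rules.get((primitive, types))
    if rule is not None:
        return rule(*values, **params)
    overriding = {t for t in types if overrides_default(t)}
    if len(overriding) == 1:
        # Case 2: exactly one type overrides default, so use its default.
        (cls,) = overriding
        return cls.default(primitive, values, params)
    if not overriding:
        # Case 3: no type overrides default, so use the base behaviour.
        return Value.default(primitive, values, params)
    # Case 4: ambiguous; the real library raises at trace time.
    names = sorted(t.__name__ for t in overriding)
    raise TypeError(f"multiple types override default: {names}")


class Plain(Value):  # hypothetical: keeps the base default
    pass


class Sparse(Value):  # hypothetical: overrides default
    @staticmethod
    def default(primitive, values, params):
        return ("sparse default", primitive)


class Lora(Value):  # hypothetical: also overrides default
    @staticmethod
    def default(primitive, values, params):
        return ("lora default", primitive)


rules = {("add", (Plain, Plain)): lambda x, y: "add rule"}

r1 = bind("add", [Plain(), Plain()], rules)   # case 1
r2 = bind("mul", [Plain(), Sparse()], rules)  # case 2
r3 = bind("mul", [Plain(), 1.0], rules)       # case 3 (1.0 is a plain arraylike)
try:
    bind("mul", [Sparse(), Lora()], rules)    # case 4
    conflict = None
except TypeError as e:
    conflict = str(e)
```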

Arguments:

  • primitive: the jax.core.Primitive being considered.
  • values: a sequence of the values this primitive is being called with. Each value is either a quax.Value or a normal JAX arraylike (i.e. bool/int/float/complex/NumPy scalar/NumPy array/JAX array).
  • params: the keyword parameters to the primitive.

Returns:

The result of binding this primitive against these types. If primitive.multiple_results is False then this should be a single quax.Value or JAX arraylike. If primitive.multiple_results is True, then this should be a tuple/list of such values.

Example

The default implementation discussed above performs the following:

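import equinox

@staticmethod
def default(primitive, values, params):
    # Materialise every quax.Value; leave plain JAX arraylikes untouched.
    arrays = [x if equinox.is_array_like(x) else x.materialise()
              for x in values]
    # Then call the usual JAX implementation of the primitive.
    return primitive.bind(*arrays, **params)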
(Using the Equinox library that underlies much of the JAX ecosystem.)

materialise(self) -> Any abstractmethod

All concrete subclasses must implement this method, specifying how to materialise this object into a JAX type (i.e. almost always a JAX array, unless you're doing something obscure using tokens or refs).

Example

For example, a LoRA array consists of three arrays (W, A, B), combined as W + AB. quax.examples.lora.LoraArray leaves these as three separate arrays for efficiency, but calling lora_array.materialise() will evaluate W + AB and return a normal JAX array.

This is so that the usual JAX primitive implementations can be applied as a fallback: the array-ish object is materialised, and then the usual JAX implementation called on it. (See quax.Value.default.)

Info

It is acceptable for this function to just raise an error -- in this case the error will be surfaced to the end user, indicating that an operation is not supported for this array-ish object.

Arguments:

Nothing.

Returns:

A JAX type; typically a JAX array.

quax.ArrayValue (Value)

A subclass of quax.Value specifically for array-like types. If you are creating a custom array-ish object, then you should typically inherit from this.

Provides the properties .shape, .dtype, .ndim, .size, each as a shortcut for self.aval().shape etc.