quax
An end user of a library built on Quax needs only one thing from this section: the quax.quaxify
function.
quax.quaxify(fn: ~CT, filter_spec: PyTree[bool | Callable[[Any], bool]] = True) -> ~CT
'Quaxifies' a function, so that it understands custom array-ish objects like quax.examples.lora.LoraArray. When this function is called, multiple dispatch will be performed against the types it is called with.
Arguments:
- fn: the function to wrap.
- filter_spec: which arguments to quaxify. Advanced usage; see the tip below.
Returns:
A copy of fn that understands all Quax types.
Only quaxifying some arguments
Calling quax.quaxify(fn, filter_spec)(*args, **kwargs) will, under the hood, run dynamic, static = eqx.partition((fn, args, kwargs), filter_spec), and then only quaxify the arguments in dynamic. This allows some quax.Values to be passed into the function unchanged, typically so that they can hit a nested quax.quaxify. See the advanced tutorial.
A developer of a library built on Quax (e.g. if you wanted to write your own library analogous to quax.examples.lora) should additionally know about the following functionality.
Info
See also the tutorials for creating your own array-ish Quax types.
quax.register(primitive: jax.extend.core.Primitive, *, precedence: int = 0) -> Callable[[~CT], ~CT]
Registers a multiple dispatch implementation for this JAX primitive.
Example
Used as a decorator, and requires type annotations to perform multiple dispatch:
@quax.register(jax.lax.add_p)
def _(x: SomeValue, y: SomeValue):
    return ...  # some implementation
All positional arguments will be (subclasses of) quax.Value; these are the set of types that Quax will attempt to perform multiple dispatch with. All keyword arguments will be the parameters for this primitive, as passed to primitive.bind(..., **params).
Arguments:
- primitive: The jax.extend.core.Primitive to provide a multiple dispatch implementation for.
- precedence: The precedence of this rule. See plum.Dispatcher.dispatch for details.
Returns:
A decorator for registering a multiple dispatch rule with the specified primitive.
quax.Value
Represents an object which Quax can perform multiple dispatch with.
In practice you will almost always want to inherit from quax.ArrayValue
instead, which represents specifically an array-ish object that can be used for
multiple dispatch.
aval() -> jax.core.AbstractValue
All concrete subclasses must implement this method, specifying the abstract value seen by JAX.
Arguments:
Nothing.
Returns:
Any subclass of jax.core.AbstractValue. Typically a jax.core.ShapedArray.
default(primitive: Primitive, values: Sequence[ArrayLike | Value], params) -> ArrayLike | Value | Sequence[ArrayLike | Value]
staticmethod
This is the default rule for when no rule has been quax.register'd for a primitive.
When performing multiple dispatch primitive.bind(value1, value2, value3), then:
- If there is a dispatch rule matching the types of value1, value2, and value3, then that will be used.
- If precisely one of the types of value{1,2,3} overloads this method, then that default rule will be used.
- If precisely zero of the types of value{1,2,3} overloads this method, then all quax.Values are materialised (via quax.Value.materialise), and the usual JAX implementation is called.
- If multiple of the types of value{1,2,3} overload this method, then a trace-time error will be raised.
Arguments:
- primitive: the jax.extend.core.Primitive being considered.
- values: a sequence of the values this primitive is being called with. Each value can be either a quax.Value or a normal JAX arraylike (i.e. bool/int/float/complex/NumPy scalar/NumPy array/JAX array).
- params: the keyword parameters to the primitive.
Returns:
The result of binding this primitive against these types. If primitive.multiple_results is False, then this should be a single quax.Value or JAX arraylike. If primitive.multiple_results is True, then this should be a tuple/list of such values.
Example
The default implementation discussed above performs the following:
@staticmethod
def default(primitive, values, params):
    arrays = [x if equinox.is_array_like(x) else x.materialise()
              for x in values]
    return primitive.bind(*arrays, **params)
materialise() -> Any
All concrete subclasses must implement this method, specifying how to materialise this object into a JAX type (i.e. almost always a JAX array, unless you're doing something obscure using tokens or refs).
Example
For example, a LoRA array consists of three arrays (W, A, B), combined as W + AB. quax.examples.lora.LoraArray leaves these as three separate arrays for efficiency, but calling lora_array.materialise() will evaluate W + AB and return a normal JAX array.
This is so that the usual JAX primitive implementations can be applied as a fallback: the array-ish object is materialised, and then the usual JAX implementation is called on it. (See quax.Value.default.)
Info
It is acceptable for this function to just raise an error -- in this case the error will be surfaced to the end user, indicating that an operation is not supported for this array-ish object.
Arguments:
Nothing.
Returns:
A JAX type; typically a JAX array.
quax.ArrayValue(quax.Value)
A subclass of quax.Value for specifically array-like types. If you are creating a custom array-ish object then you should typically inherit from this.
Provides the properties .shape, .dtype, .ndim and .size, each as a shortcut for self.aval().shape etc.