Extra features
Converters and static fields
Equinox modules are dataclasses. Equinox extends this support with converters and static fields.
equinox.field(*, converter: Callable[[Any], Any] = lambda x: x, static: bool = False, **kwargs)
Equinox supports extra functionality on top of the default dataclasses.

Arguments:

- `converter`: a function to call on this field when the model is initialised. For example, `field(converter=jax.numpy.asarray)` to convert `bool`/`int`/`float`/`complex` values to JAX arrays. This is run after the `__init__` method (i.e. when using a user-provided `__init__`), and before `__post_init__` (i.e. when using the default dataclass initialisation).
- `static`: whether the field should not interact with any JAX transform at all (by making it part of the PyTree structure rather than a leaf).
- `**kwargs`: all other keyword arguments are passed on to `dataclasses.field`.
Example for converter

```python
import equinox as eqx
import jax

class MyModule(eqx.Module):
    foo: jax.Array = eqx.field(converter=jax.numpy.asarray)

mymodule = MyModule(1.0)
assert isinstance(mymodule.foo, jax.Array)
```
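Example for static

```python
import equinox as eqx
import jax

class MyModule(eqx.Module):
    normal_field: str
    static_field: str = eqx.field(static=True)

mymodule = MyModule("normal", "static")
leaves, treedef = jax.tree_util.tree_flatten(mymodule)
assert leaves == ["normal"]  # only the non-static field is a leaf
assert "static" in str(treedef)  # the static value lives in the tree structure
```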
`static=True` means that this field is not a node of the PyTree, so it does not interact with any JAX transforms, like JIT or grad. This means that it is usually a bug to make JAX arrays static fields. `static=True` should very rarely be used.

It is preferred to just filter out each field with `eqx.partition` whenever you need to select only some fields.
Abstract attributes
Equinox modules can be used as abstract base classes, which means they support `abc.abstractmethod`. Equinox extends this with support for abstract instance attributes and abstract class attributes.
equinox.AbstractVar

Used to mark an abstract instance attribute, along with its type. Used as:

```python
class Foo(eqx.Module):
    attr: AbstractVar[bool]
```
An `AbstractVar[T]` must be overridden by an attribute annotated with `AbstractVar[T]`, `AbstractClassVar[T]`, `ClassVar[T]`, `T`, or a property returning `T`.

This makes `AbstractVar` useful when you just want to assert that you can access `self.attr` on a subclass, regardless of whether it's an instance attribute, class attribute, property, etc.

Attempting to instantiate a module with an unoverridden `AbstractVar` will raise an error.
Example

```python
import equinox as eqx
from equinox import AbstractVar

class AbstractX(eqx.Module):
    attr1: int
    attr2: AbstractVar[bool]

class ConcreteX(AbstractX):
    attr2: bool

ConcreteX(attr1=1, attr2=True)
```
Info

`AbstractVar` does not create a dataclass field. This affects the order of `__init__` arguments. E.g.

```python
from equinox import AbstractVar, Module

class AbstractX(Module):
    attr1: AbstractVar[bool]

class ConcreteX(AbstractX):
    attr2: str
    attr1: bool
```

should be called as `ConcreteX(attr2, attr1)`.
equinox.AbstractClassVar

Used to mark an abstract class attribute, along with its type. Used as:

```python
class Foo(eqx.Module):
    attr: AbstractClassVar[bool]
```
An `AbstractClassVar[T]` can be overridden by an attribute annotated with `AbstractClassVar[T]` or `ClassVar[T]`. This makes `AbstractClassVar` useful when you want to assert that you can access `cls.attr` on a subclass.

Attempting to instantiate a module with an unoverridden `AbstractClassVar` will raise an error.
Example

```python
import equinox as eqx
from equinox import AbstractClassVar
from typing import ClassVar

class AbstractX(eqx.Module):
    attr1: int
    attr2: AbstractClassVar[bool]

class ConcreteX(AbstractX):
    attr2: ClassVar[bool] = True

ConcreteX(attr1=1)
```
Info

`AbstractClassVar` does not create a dataclass field. This affects the order of `__init__` arguments. E.g.

```python
from typing import ClassVar
from equinox import AbstractClassVar, Module

class AbstractX(Module):
    attr1: AbstractClassVar[bool]

class ConcreteX(AbstractX):
    attr2: str
    attr1: ClassVar[bool] = True
```

should be called as `ConcreteX(attr2)`.
Known issues
Due to a Pyright bug (#4965), this must be imported as:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import ClassVar as AbstractClassVar
else:
    from equinox import AbstractClassVar
```
Checking invariants
Equinox extends dataclasses with a `__check_init__` method, which is automatically run after initialisation. This can be used to check invariants like so:

```python
import equinox as eqx

class Positive(eqx.Module):
    x: int

    def __check_init__(self):
        if self.x <= 0:
            raise ValueError("Oh no!")
```
This method has three key differences compared to the `__post_init__` provided by dataclasses:

- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a bug (Equinox will raise a warning if you do this):

    ```python
    import equinox as eqx

    class Parent(eqx.Module):
        x: int

        def __post_init__(self):
            if self.x <= 0:
                raise ValueError("Oh no!")

    class Child(Parent):
        x_as_str: str

        def __init__(self, x):
            self.x = x
            self.x_as_str = str(x)

    Child(-1)  # No error!
    ```

- It is automatically called for parent classes; `super().__check_init__()` is not required:

    ```python
    class Parent(eqx.Module):
        def __check_init__(self):
            print("Parent")

    class Child(Parent):
        def __check_init__(self):
            print("Child")

    Child()  # prints out both Child and Parent
    ```

    As with the previous bullet point, this is to prevent child classes accidentally failing to check that the invariants of their parent hold.

- Assignment is not allowed:

    ```python
    class MyModule(eqx.Module):
        foo: int

        def __check_init__(self):
            self.foo = 1  # will raise an error
    ```

    This is to prevent `__check_init__` from doing anything too surprising: as the name suggests, it's meant to be used for checking invariants.
Creating wrapper modules
equinox.module_update_wrapper(wrapper: Module, wrapped: Optional[Callable[_P, _T]] = None) -> Callable[_P, _T]
Like `functools.update_wrapper` (or its better-known cousin, `functools.wraps`), but acts on `equinox.Module`s, and does not modify its input (it returns the updated module instead).
Example

```python
from typing import Callable

import equinox as eqx

class Wrapper(eqx.Module):
    fn: Callable

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    @property
    def __wrapped__(self):
        return self.fn

def make_wrapper(fn):
    return eqx.module_update_wrapper(Wrapper(fn))
```
For example, `equinox.filter_jit` returns a module representing the JIT'd computation. `module_update_wrapper` is used on this module to indicate that this JIT'd computation wraps the original one (just like how `functools.wraps` is used).

Note that, as in the above example, the wrapper class must supply a `__wrapped__` property, which redirects to the wrapped object.
Arguments:

- `wrapper`: the instance of the wrapper.
- `wrapped`: optional, the callable that is being wrapped. If omitted then `wrapper.__wrapped__` will be used.
Returns:

A copy of `wrapper`, with the attributes `__module__`, `__name__`, `__qualname__`, `__doc__`, and `__annotations__` copied over from the wrapped function.