Extra features

Converters and static fields

Equinox modules are dataclasses. Equinox extends this support with converters and static fields.

equinox.field(*, converter: Callable[[Any], Any] = lambda x: x, static: bool = False, **kwargs)

Equinox supports extra functionality on top of the default dataclasses.

Arguments:

  • converter: a function to call on this field when the model is initialised. For example, field(converter=jax.numpy.asarray) converts bool/int/float/complex values to JAX arrays. The converter is run after the __init__ method (when using a user-provided __init__), and before __post_init__ (when using the default dataclass initialisation).
  • static: whether the field should not interact with any JAX transform at all (by making it part of the PyTree structure rather than a leaf).
  • **kwargs: all other keyword arguments are passed on to dataclasses.field.

Example for converter

import equinox as eqx
import jax
import jax.numpy as jnp

class MyModule(eqx.Module):
    foo: jax.Array = eqx.field(converter=jnp.asarray)

mymodule = MyModule(1.0)
assert isinstance(mymodule.foo, jax.Array)

Example for static

import equinox as eqx
import jax

class MyModule(eqx.Module):
    normal_field: str
    static_field: str = eqx.field(static=True)

mymodule = MyModule("normal", "static")
leaves, treedef = jax.tree_util.tree_flatten(mymodule)
assert leaves == ["normal"]
assert "static" in str(treedef)

static=True means that this field is not a leaf of the PyTree: it is stored as part of the PyTree structure instead, so it does not interact with JAX transforms such as JIT or grad. For this reason it is usually a bug to mark a JAX array as a static field. static=True should be used only rarely; when you need to select only some fields, prefer filtering them out with eqx.partition.

Abstract attributes

Equinox modules can be used as abstract base classes, which means they support abc.abstractmethod. Equinox extends this with support for abstract instance attributes and abstract class attributes.

equinox.AbstractVar

Used to mark an abstract instance attribute, along with its type. Used as:

class Foo(eqx.Module):
    attr: AbstractVar[bool]

An AbstractVar[T] must be overridden by an attribute annotated with AbstractVar[T], AbstractClassVar[T], ClassVar[T], T, or a property returning T.

This makes AbstractVar useful when you just want to assert that you can access self.attr on a subclass, regardless of whether it's an instance attribute, class attribute, property, etc.

Attempting to instantiate a module with an un-overridden AbstractVar will raise an error.

Example

import equinox as eqx
from equinox import AbstractVar

class AbstractX(eqx.Module):
    attr1: int
    attr2: AbstractVar[bool]

class ConcreteX(AbstractX):
    attr2: bool

ConcreteX(attr1=1, attr2=True)

Info

AbstractVar does not create a dataclass field. This affects the order of __init__ arguments. E.g.

class AbstractX(eqx.Module):
    attr1: AbstractVar[bool]

class ConcreteX(AbstractX):
    attr2: str
    attr1: bool

should be called as ConcreteX(attr2, attr1).

equinox.AbstractClassVar

Used to mark an abstract class attribute, along with its type. Used as:

class Foo(eqx.Module):
    attr: AbstractClassVar[bool]

An AbstractClassVar[T] can be overridden by an attribute annotated with AbstractClassVar[T] or ClassVar[T]. This makes AbstractClassVar useful when you want to assert that you can access cls.attr on a subclass.

Attempting to instantiate a module with an un-overridden AbstractClassVar will raise an error.

Example

import equinox as eqx
from equinox import AbstractClassVar
from typing import ClassVar

class AbstractX(eqx.Module):
    attr1: int
    attr2: AbstractClassVar[bool]

class ConcreteX(AbstractX):
    attr2: ClassVar[bool] = True

ConcreteX(attr1=1)

Info

AbstractClassVar does not create a dataclass field. This affects the order of __init__ arguments. E.g.

class AbstractX(eqx.Module):
    attr1: AbstractClassVar[bool]

class ConcreteX(AbstractX):
    attr2: str
    attr1: ClassVar[bool] = True

should be called as ConcreteX(attr2).

Known issues

Due to a Pyright bug (#4965), this must be imported as:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import ClassVar as AbstractClassVar
else:
    from equinox import AbstractClassVar

Checking invariants

Equinox extends dataclasses with a __check_init__ method, which is automatically run after initialisation. This can be used to check invariants like so:

class Positive(eqx.Module):
    x: int

    def __check_init__(self):
        if self.x <= 0:
            raise ValueError("Oh no!")

This method has three key differences compared to the __post_init__ provided by dataclasses:

  • It is not overridden by an __init__ method of a subclass. In contrast, the following code has a bug (Equinox will raise a warning if you do this):

    class Parent(eqx.Module):
        x: int
    
        def __post_init__(self):
            if self.x <= 0:
                raise ValueError("Oh no!")
    
    class Child(Parent):
        x_as_str: str
    
        def __init__(self, x):
            self.x = x
            self.x_as_str = str(x)
    
    Child(-1)  # No error!
    
  • It is automatically called for parent classes; super().__check_init__() is not required:

    class Parent(eqx.Module):
        def __check_init__(self):
            print("Parent")
    
    class Child(Parent):
        def __check_init__(self):
            print("Child")
    
    Child()  # prints out both Child and Parent
    

    As with the previous bullet point, this is to prevent child classes from accidentally failing to check that the invariants of their parent hold.

  • Assignment is not allowed:

    class MyModule(eqx.Module):
        foo: int
    
        def __check_init__(self):
            self.foo = 1  # will raise an error
    

    This is to prevent __check_init__ from doing anything too surprising: as the name suggests, it's meant to be used for checking invariants.

Creating wrapper modules

equinox.module_update_wrapper(wrapper: Module, wrapped: Optional[Callable[~_P, ~_T]] = None) -> Callable[~_P, ~_T]

Like functools.update_wrapper (or its better-known cousin, functools.wraps), but acts on equinox.Modules, and does not modify its input (it returns the updated module instead).

Example

import equinox as eqx
from typing import Callable

class Wrapper(eqx.Module):
    fn: Callable

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    @property
    def __wrapped__(self):
        return self.fn

def make_wrapper(fn):
    return eqx.module_update_wrapper(Wrapper(fn))

For example, equinox.filter_jit returns a module representing the JIT'd computation. module_update_wrapper is used on this module to indicate that this JIT'd computation wraps the original one. (Just like how functools.wraps is used.)

Note that as in the above example, the wrapper class must supply a __wrapped__ property, which redirects to the wrapped object.

Arguments:

  • wrapper: the instance of the wrapper.
  • wrapped: optional, the callable that is being wrapped. If omitted then wrapper.__wrapped__ will be used.

Returns:

A copy of wrapper, with the attributes __module__, __name__, __qualname__, __doc__, and __annotations__ copied over from the wrapped function.