Module

equinox.Module

Base class. Create your model by inheriting from this.

This will make your model a dataclass and a pytree.

Fields

Specify all of the module's fields at the class level, exactly as you would for a dataclass. These fields are the module's children as a PyTree.

import equinox
import jax

class MyModule(equinox.Module):
    weight: jax.Array
    bias: jax.Array
    submodule: equinox.Module

Initialisation

A default __init__ is automatically provided, which just fills in the fields with the arguments passed, in the order the fields are declared. For example, MyModule(weight, bias, submodule).

Alternatively (quite commonly) you can provide an __init__ method yourself:

import equinox
import jax

class MyModule(equinox.Module):
    weight: jax.Array
    bias: jax.Array
    submodule: equinox.Module

    def __init__(self, in_size, out_size, key):
        # Split the key so that each randomly initialised component gets its own subkey.
        wkey, bkey, skey = jax.random.split(key, 3)
        # Every declared field is assigned here.
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))
        self.submodule = equinox.nn.Linear(in_size, out_size, key=skey)

Methods

It is common to create some methods on the class -- for example to define the forward pass of a model.

class MyModule(equinox.Module):
    ...  # as above

    def __call__(self, x):
        return self.submodule(x) + self.weight @ x + self.bias

Tip

You don't have to define __call__:

  • You can define other methods if you want.
  • You can define multiple methods if you want.
  • You can define no methods if you want. (And just use equinox.Module as a nice syntax for custom PyTrees.)

No method is special-cased.

Usage

Once you have defined your model, you can use it just like any other PyTree -- one that just happens to have some methods attached. In particular, you can pass it across jax.jit, jax.grad, etc. in exactly the way that you're used to.

Example

If you wanted to, it would be completely safe to write

class MyModule(equinox.Module):
    ...

    @jax.jit
    def __call__(self, x):
        ...

because self is just a PyTree. Unlike most other neural network libraries, you can mix Equinox and native JAX without any difficulties at all.

For fans of strong typing.

Equinox modules are all ABCs by default, so you can use abc.abstractmethod. You can also create abstract instance attributes or abstract class attributes; see equinox.AbstractVar and equinox.AbstractClassVar.