# Module

#### equinox.Module

Base class. Create your model by inheriting from this.

Fields

Specify all of the module's fields at the class level (identical to dataclasses). These fields define its children as a PyTree.

    class MyModule(equinox.Module):
        weight: jax.numpy.ndarray
        bias: jax.numpy.ndarray
        submodule: equinox.Module


Initialisation

A default __init__ is automatically provided, which just fills in fields with the arguments passed. For example MyModule(weight, bias, submodule).

Alternatively (quite commonly) you can provide an __init__ method yourself:

    class MyModule(equinox.Module):
        weight: jax.numpy.ndarray
        bias: jax.numpy.ndarray
        submodule: equinox.Module

        def __init__(self, in_size, out_size, key):
            wkey, bkey, skey = jax.random.split(key, 3)
            self.weight = jax.random.normal(wkey, (out_size, in_size))
            self.bias = jax.random.normal(bkey, (out_size,))
            self.submodule = equinox.nn.Linear(in_size, out_size, key=skey)


Methods

It is common to create some methods on the class -- for example to define the forward pass of a model.

    class MyModule(equinox.Module):
        ...  # as above

        def __call__(self, x):
            return self.submodule(x) + self.weight @ x + self.bias


Tip

You don't have to define __call__:

• You can define other methods if you want.
• You can define multiple methods if you want.
• You can define no methods if you want. (And just use equinox.Module as a nice syntax for custom PyTrees.)

No method is special-cased.

Usage

Once you have defined your model, you can use it just like any other PyTree -- one that just happens to have some methods attached. In particular, you can pass it across jax.jit, jax.grad etc. in exactly the way that you're used to.

Example

If you wanted to, it would be completely safe to write

    class MyModule(equinox.Module):
        ...

        @jax.jit
        def __call__(self, x):
            ...


because self is just a PyTree. Unlike most other neural network libraries, you can mix Equinox and native JAX without any difficulties at all.