Module
equinox.Module
Base class. Create your model by inheriting from this.
This will make your model a dataclass and a pytree.
Fields
Specify all of its fields at the class level, with type annotations (identical to dataclasses). These fields define the module's children when it is viewed as a PyTree.
```python
import equinox
import jax


class MyModule(equinox.Module):
    weight: jax.Array
    bias: jax.Array
    submodule: equinox.Module
```
Initialisation
A default `__init__` is automatically provided, which just fills in fields with the arguments passed. For example `MyModule(weight, bias, submodule)`.
Alternatively (and quite commonly) you can provide an `__init__` method yourself:
```python
import equinox
import jax


class MyModule(equinox.Module):
    weight: jax.Array
    bias: jax.Array
    submodule: equinox.Module

    def __init__(self, in_size, out_size, key):
        # Split the key so that each parameter gets independent randomness.
        wkey, bkey, skey = jax.random.split(key, 3)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))
        self.submodule = equinox.nn.Linear(in_size, out_size, key=skey)
```
Methods
It is common to create some methods on the class, for example to define the forward pass of a model.
```python
class MyModule(equinox.Module):
    ...  # fields and __init__ as above

    def __call__(self, x):
        return self.submodule(x) + self.weight @ x + self.bias
```
Tip
You don't have to define `__call__`:
- You can define other methods if you want.
- You can define multiple methods if you want.
- You can define no methods if you want. (And just use `equinox.Module` as a nice syntax for custom PyTrees.)
No method is special-cased.
Usage
After you have defined your model, you can use it just like any other PyTree (one that just happens to have some methods attached). In particular, you can pass it across `jax.jit`, `jax.grad`, etc. in exactly the way that you're used to.
Example
If you wanted to, then it would be completely safe to do:
```python
class MyModule(equinox.Module):
    ...

    @jax.jit
    def __call__(self, x):
        ...
```
because `self` is just a PyTree. Unlike most other neural network libraries, you can mix Equinox and native JAX without any difficulties at all.
For fans of strong typing.
Equinox modules are all ABCs (abstract base classes) by default. This means you can use `abc.abstractmethod`.
You can also create abstract instance attributes or abstract class attributes; see `equinox.AbstractVar` and `equinox.AbstractClassVar`.