Module
equinox.Module
Base class. Create your model by inheriting from this.
Fields
Declare all of its fields at the class level, as type-annotated attributes (exactly as with dataclasses). These fields become the module's children when it is treated as a PyTree.
import jax
import jax.numpy
import equinox

class MyModule(equinox.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray
    submodule: equinox.Module
Initialisation
A default __init__ is automatically provided, which just fills in the fields with the arguments passed. For example, MyModule(weight, bias, submodule).
Alternatively (and quite commonly), you can provide an __init__ method yourself:
import jax
import equinox

class MyModule(equinox.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray
    submodule: equinox.Module

    def __init__(self, in_size, out_size, key):
        # Split one PRNG key into independent keys for each parameter.
        wkey, bkey, skey = jax.random.split(key, 3)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))
        self.submodule = equinox.nn.Linear(in_size, out_size, key=skey)
Methods
It is common to define methods on the class, for example to implement the forward pass of a model.
class MyModule(equinox.Module):
    ...  # fields and __init__ as above

    def __call__(self, x):
        return self.submodule(x) + self.weight @ x + self.bias
Tip
You don't have to define __call__:
- You can define other methods if you want.
- You can define multiple methods if you want.
- You can define no methods if you want (and just use equinox.Module as a nice syntax for custom PyTrees).
No method is special-cased.
Usage
After you have defined your model, you can use it just like any other PyTree that happens to have some methods attached. In particular, you can pass it across jax.jit, jax.grad, etc. in exactly the way that you're used to.
Example
If you wanted to, it would be completely safe to do:
import jax
import equinox

class MyModule(equinox.Module):
    ...  # fields and __init__ as above

    @jax.jit
    def __call__(self, x):
        ...  # forward pass
because self is just a PyTree. Unlike most other neural network libraries, you can mix Equinox and native JAX without any difficulties at all.