Base class. Create your model by inheriting from this.
Specify all its fields at the class level (identical to dataclasses). This defines its children as a PyTree.
```python
import equinox
import jax

class MyModule(equinox.Module):
    weight: jax.Array
    bias: jax.Array
    submodule: equinox.Module
```
A default `__init__` is automatically provided, which just fills in the fields with
the arguments passed. For example,
`MyModule(weight, bias, submodule)`.
Alternatively (quite commonly) you can provide an
`__init__` method yourself:
```python
import equinox
import jax

class MyModule(equinox.Module):
    weight: jax.Array
    bias: jax.Array
    submodule: equinox.Module

    def __init__(self, in_size, out_size, key):
        # One independent PRNG key per randomly initialised field.
        wkey, bkey, skey = jax.random.split(key, 3)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))
        self.submodule = equinox.nn.Linear(in_size, out_size, key=skey)
```
It is common to create some methods on the class -- for example to define the forward pass of a model.
```python
class MyModule(equinox.Module):
    ...  # as above

    def __call__(self, x):
        return self.submodule(x) + self.weight @ x + self.bias
```
You don't have to define `__call__`:
- You can define other methods if you want.
- You can define multiple methods if you want.
- You can define no methods if you want. (And just use
  `equinox.Module` as a nice syntax for custom PyTrees.)
No method is special-cased.
After you have defined your model, you can use it just like any other
PyTree -- one that just happens to have some methods attached. In particular, you can
pass it across JAX transformations such as
`jax.grad` in exactly the way that you're used to.
If you wanted to, then it would be completely safe to do
```python
class MyModule(equinox.Module):
    ...

    @jax.jit
    def __call__(self, x):
        ...
```
because `self` is just a PyTree. Unlike most other neural network libraries,
you can mix Equinox and native JAX without any difficulties at all.