Base class. Create your model by inheriting from this.
Specify all its fields at the class level (exactly as with dataclasses). These fields become its children when it is treated as a PyTree.
```python
import equinox
import jax

class MyModule(equinox.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray
    submodule: equinox.Module
```
An __init__ method is provided automatically; it simply fills in the fields with the arguments passed. For example: MyModule(weight, bias, submodule).
Alternatively (quite commonly) you can provide an
__init__ method yourself:
```python
import equinox
import jax

class MyModule(equinox.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray
    submodule: equinox.Module

    def __init__(self, in_size, out_size, key):
        # Split one PRNG key into independent keys for each parameter.
        wkey, bkey, skey = jax.random.split(key, 3)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))
        self.submodule = equinox.nn.Linear(in_size, out_size, key=skey)
```
It is common to create some methods on the class -- for example to define the forward pass of a model.
```python
class MyModule(equinox.Module):
    ...  # as above

    def __call__(self, x):
        return self.submodule(x) + self.weight @ x + self.bias
```
You don't have to define __call__:
- You can define other methods if you want.
- You can define multiple methods if you want.
- You can define no methods if you want. (And just use equinox.Module as a nice syntax for custom PyTrees.)
No method is special-cased.
After you have defined your model, you can use it just like any other PyTree (one that happens to have some methods attached). In particular, you can pass it to jax.grad etc. in exactly the way that you're used to.
If you wanted to, it would be completely safe to do:

```python
class MyModule(equinox.Module):
    ...

    @jax.jit
    def __call__(self, x):
        ...
```
This works because self is just a PyTree. Unlike most other neural network libraries, you can mix Equinox and native JAX without any difficulties at all.