Compatibility with init-apply libraries
Existing JAX neural network libraries have sometimes followed the "init/apply" approach, in which the parameters of a network are initialised with a function init(), and then the forward pass through a model is specified with apply(). For example, Stax follows this approach.
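As a rough illustration of that convention (this snippet uses jax.example_libraries.stax and is not part of Equinox; older JAX versions exposed the same module as jax.experimental.stax):

import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# init_fn(rng, input_shape) returns (output_shape, params);
# apply_fn(params, inputs) runs the forward pass as a pure function of params.
init_fn, apply_fn = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1))

key = jax.random.PRNGKey(0)
_, params = init_fn(key, (2,))
y = apply_fn(params, jnp.ones(2))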
As a result, some third-party libraries assume that your model is specified by an init() and an apply() function, and that the parameters returned from init() are all JIT-trace-able and grad-able.
Equinox can be made to fit with this style very easily, like so.
This example is available as a Jupyter notebook here.
import equinox as eqx

def make_mlp(in_size, out_size, width_size, depth, *, key):
    mlp = eqx.nn.MLP(
        in_size, out_size, width_size, depth, key=key
    )  # insert your model here
    # Split the model into its trainable arrays ("params") and everything else ("static").
    params, static = eqx.partition(mlp, eqx.is_inexact_array)

    def init_fn():
        return params

    def apply_fn(_params, x):
        # Recombine the parameters with the static pieces to rebuild the full model.
        model = eqx.combine(_params, static)
        return model(x)

    return init_fn, apply_fn
And that's all there is to it.
Example usage:
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu

def main(in_size=2, seed=5678):
    key = jrandom.PRNGKey(seed)
    init_fn, apply_fn = make_mlp(
        in_size=in_size, out_size=1, width_size=8, depth=1, key=key
    )
    x = jnp.arange(in_size)  # sample data
    params = init_fn()
    y1 = apply_fn(params, x)
    params = jtu.tree_map(lambda p: p + 1, params)  # "stochastic gradient descent"
    y2 = apply_fn(params, x)
    assert y1 != y2

main()
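Because init_fn() returns only inexact (floating-point) JAX arrays, the parameters can also be passed straight through transformations like jax.jit and jax.grad, which is what init/apply-style libraries expect. As a minimal sketch continuing the example above (the loss function here is illustrative, not part of the original example):

import jax
import jax.numpy as jnp
import jax.random as jrandom

init_fn, apply_fn = make_mlp(
    in_size=2, out_size=1, width_size=8, depth=1, key=jrandom.PRNGKey(0)
)

@jax.jit
@jax.grad
def loss_grad(params, x, y):
    # Gradients are taken with respect to the parameter pytree returned by init_fn().
    pred = apply_fn(params, x)
    return jnp.sum((pred - y) ** 2)

grads = loss_grad(init_fn(), jnp.ones(2), jnp.zeros(1))  # a pytree of gradients, matching params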