Compatibility with init-apply libraries
Existing JAX neural network libraries have sometimes followed the "init/apply" approach, in which the parameters of a network are initialised with a function `init()`, and the forward pass through the model is then specified with a function `apply()`. For example, Stax follows this approach.

As a result, some third-party libraries assume that your model is specified by an `init()` function and an `apply()` function, and that the parameters returned from `init()` can all be traced with `jax.jit` and differentiated with `jax.grad`.
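For reference, here is a minimal sketch of what the pattern looks like in Stax. (Note that Stax's `init` additionally takes an RNG key and an input shape; the layer sizes here are arbitrary.)

```python
import jax.numpy as jnp
import jax.random as jrandom
from jax.example_libraries import stax

# `stax.serial` returns an (init, apply) pair of functions.
init_fn, apply_fn = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1))
output_shape, params = init_fn(jrandom.PRNGKey(0), (-1, 2))
y = apply_fn(params, jnp.ones((1, 2)))
```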
Equinox can be made to fit this style very easily, like so:
```python
import equinox as eqx


def make_mlp(in_size, out_size, width_size, depth, *, key):
    # Insert your model here.
    mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
    # Split the model into its trainable arrays and everything else.
    params, static = eqx.partition(mlp, eqx.is_inexact_array)

    def init_fn():
        return params

    def apply_fn(_params, x):
        # Put the model back together before calling it.
        model = eqx.combine(_params, static)
        return model(x)

    return init_fn, apply_fn
```
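The key pieces here are `eqx.partition` and `eqx.combine`. `partition` splits the model pytree into two complementary pytrees with the same structure: one holding just the leaves that match the filter (here, the floating-point arrays, i.e. the parameters), and one holding everything else (activation functions, sizes, and so on). `combine` reassembles them. As a toy illustration:

```python
import equinox as eqx
import jax.numpy as jnp

pytree = [jnp.array([1.0, 2.0]), jnp.array([3, 4]), "metadata"]
params, static = eqx.partition(pytree, eqx.is_inexact_array)
# params is [Array([1., 2.]), None, None]     -- just the float arrays
# static is [None, Array([3, 4]), "metadata"] -- everything else
reassembled = eqx.combine(params, static)  # the original pytree again
```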
And that's all there is to it. For example:
```python
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu


def main(in_size=2, seed=5678):
    key = jrandom.PRNGKey(seed)
    init_fn, apply_fn = make_mlp(
        in_size=in_size, out_size=1, width_size=8, depth=1, key=key
    )
    x = jnp.arange(in_size)  # sample data
    params = init_fn()
    y1 = apply_fn(params, x)
    params = jtu.tree_map(lambda p: p + 1, params)  # "stochastic gradient descent"
    y2 = apply_fn(params, x)
    assert y1 != y2


main()
```
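Since the parameters returned by `init_fn` form an ordinary pytree of JAX arrays, they can also be passed straight through transformations such as `jax.jit` and `jax.grad`. A minimal sketch, reusing `make_mlp` from above (the loss function is just an arbitrary example):

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

init_fn, apply_fn = make_mlp(
    in_size=2, out_size=1, width_size=8, depth=1, key=jrandom.PRNGKey(0)
)
params = init_fn()
x = jnp.arange(2.0)


@jax.jit
def loss_fn(params, x):
    # An arbitrary scalar loss, just to demonstrate differentiability.
    return jnp.sum(apply_fn(params, x) ** 2)


grads = jax.grad(loss_fn)(params, x)  # same pytree structure as `params`
```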