Compatibility with init-apply libraries

Some existing JAX neural network libraries follow the "init/apply" approach, in which a network's parameters are initialised by a function init(), and the forward pass through the model is specified by a separate function apply(). For example, Stax follows this approach.
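For comparison, here is a minimal sketch of the init/apply style in Stax, which ships with JAX as jax.example_libraries.stax. The layer sizes are arbitrary choices for illustration. Note that Stax's init function takes a PRNG key and an input shape, and returns the output shape alongside the parameters, so the exact init() signature varies from library to library.

```python
import jax.numpy as jnp
import jax.random as jrandom
from jax.example_libraries import stax

# serial() combines layers and returns an (init_fun, apply_fun) pair.
init_fun, apply_fun = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1))

# init_fun takes a PRNG key and the input shape (-1 stands for the batch
# dimension), and returns the output shape together with the parameters.
out_shape, params = init_fun(jrandom.PRNGKey(0), (-1, 2))

# apply_fun receives the parameters explicitly on every call.
x = jnp.ones((3, 2))  # a batch of three 2-dimensional inputs
y = apply_fun(params, x)
print(out_shape, y.shape)  # (-1, 1) (3, 1)
```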

As a result, some third-party libraries assume that your model is specified by an init() function and an apply() function, and that the parameters returned from init() can all be traced by jax.jit and differentiated with jax.grad.

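Equinox models can be adapted to this style easily: use eqx.partition to split the model into its parameters (its inexact, i.e. floating-point, arrays) and everything else, return the parameters from init(), and recombine the two halves with eqx.combine inside apply(), like so.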

import equinox as eqx

def make_mlp(in_size, out_size, width_size, depth, *, key):
    mlp = eqx.nn.MLP(
        in_size, out_size, width_size, depth, key=key
    )  # insert your model here
    params, static = eqx.partition(mlp, eqx.is_inexact_array)

    def init_fn():
        return params

    def apply_fn(_params, x):
        model = eqx.combine(_params, static)
        return model(x)

    return init_fn, apply_fn

And that's all there is to it.

Example usage:

import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu

def main(in_size=2, seed=5678):
    key = jrandom.PRNGKey(seed)

    init_fn, apply_fn = make_mlp(
        in_size=in_size, out_size=1, width_size=8, depth=1, key=key
    )

    x = jnp.arange(in_size)  # sample data
    params = init_fn()
    y1 = apply_fn(params, x)
    params = jtu.tree_map(lambda p: p + 1, params)  # "stochastic gradient descent"
    y2 = apply_fn(params, x)
    assert y1 != y2


main()