
Filtered transformations

These typically combine equinox.partition, a filter function, and a JAX transformation, all together.

Practically speaking, these are usually the only kind of filtering you ever have to use. (But it's good to understand what e.g. equinox.partition and equinox.is_array are doing under the hood, just so that these don't seem too magical.)

equinox.filter_jit(fun = sentinel, *, default = <function is_array>, fn = <function is_array>, args = (), kwargs = None, out = <function is_array>, **jitkwargs)

Wraps together equinox.partition and jax.jit.

Info

By default, all JAX arrays are traced, and all other types are held static.

Arguments:

In each of the following cases, True indicates that an argument should be traced, False indicates that an argument should be held static, and functions Leaf -> bool are mapped and evaluated on every leaf of their subtree.

  • fun is a pure function to JIT compile.
  • default should be a bool or a function Leaf -> bool, and is applied by default to every argument and keyword argument to fun.
  • args is an optional per-argument override for default, and should be a tuple of PyTrees with leaves that are either bools or functions Leaf -> bool. The PyTree structures should be prefixes of the corresponding input to fun.
  • kwargs is an optional per-keyword-argument override for default and should be a dictionary, whose keys are the names of arguments to fun, and whose values are PyTrees with leaves that are either bools or functions Leaf -> bool. The PyTree structures should be prefixes of the corresponding input to fun.
  • out should be a PyTree with leaves that are either bools or functions Leaf -> bool. The PyTree structure should be a prefix of the output of fun. Truthy values should be tracers; falsey values are any (non-tracer) auxiliary information to return.
  • fn should be a PyTree with leaves that are either bools or functions Leaf -> bool. The PyTree structure should be a prefix of fun itself. (Note that fun may be any callable, e.g. a bound method, or a class implementing __call__, and doesn't have to be a normal Python function.)
  • **jitkwargs are any other keyword arguments to jax.jit.

When args, kwargs, out, fn are prefixes of the corresponding input, their value will be mapped over the input PyTree.

Returns:

The JIT'd version of fun.

Example

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_jit
def f(x, y):  # both args traced if arrays, static if non-arrays
    return x + y

@eqx.filter_jit(kwargs=dict(x=False))
def g(x, y):  # x held static; y is traced if array, static if non-array
    return x + y

@eqx.filter_jit(args=(True,))
def h(x):
    return x

@eqx.filter_jit
def apply(f, x):
    return f(x)

f(jnp.array(1), jnp.array(2))  # both args traced
f(jnp.array(1), 2)  # first arg traced, second arg static
f(1, 2)  # both args static

g(1, jnp.array(2))  # first arg static, second arg traced
g(1, 2)  # both args static

h(1)  # traced
h(jnp.array(1))  # traced
h("hi")  # not a trace-able JAX type, so error

apply(lambda x: x + 1, jnp.array(1))  # first arg static, second arg traced.

equinox.filter_grad(fun = sentinel, *, arg = <function is_inexact_array>, **gradkwargs)

Wraps together equinox.partition and jax.grad.

Info

By default, all inexact (floating-point) JAX arrays are differentiated. Any nondifferentiable leaves will have None as the gradient.

Arguments:

  • fun is a pure function to differentiate.
  • arg is a PyTree whose structure should be a prefix of the structure of the first argument to fun. It behaves as the filter_spec argument to equinox.filter. Truthy values will be differentiated; falsey values will not.
  • **gradkwargs are any other keyword arguments to jax.grad.

Returns:

A function computing the derivative of fun with respect to its first input. Any nondifferentiable leaves will have None as the gradient. See equinox.apply_updates for a convenience function that will only attempt to apply non-None updates.

Tip

If you need to differentiate multiple objects, then put them together into a tuple and pass that through the first argument:

# We want to differentiate `func` with respect to both `x` and `y`.
def func(x, y):
    ...

@equinox.filter_grad
def grad_func(x__y):
    x, y = x__y
    return func(x, y)


equinox.filter_value_and_grad(fun = sentinel, *, arg = <function is_inexact_array>, **gradkwargs)

As equinox.filter_grad, except that it is jax.value_and_grad that is wrapped.


equinox.filter_custom_vjp(fn)

Provides an easier API for jax.custom_vjp, by using filtering.

Usage is:

@equinox.filter_custom_vjp
def fn(vjp_arg, *args, **kwargs):
    # vjp_arg is some PyTree of arbitrary Python objects.
    # args, kwargs contain arbitrary Python objects.
    ...
    return obj  # some PyTree of arbitrary Python objects.

def fn_fwd(vjp_arg, *args, **kwargs):
    ...
    # Should return `obj` as before. `residuals` can be any collection of JAX
    # arrays you want to keep around for the backward pass.
    return obj, residuals

def fn_bwd(residuals, grad_obj, vjp_arg, *args, **kwargs):
    # grad_obj will have `None` as the gradient for any leaves of `obj` that were
    # not JAX arrays
    ...
    # grad_vjp_arg should have `None` as the gradient for any leaves of `vjp_arg`
    # that were not JAX arrays.
    return grad_vjp_arg

fn.defvjp(fn_fwd, fn_bwd)

The key differences to jax.custom_vjp are that:

  • Only the gradient of the first argument, vjp_arg, should be computed on the backward pass. Everything else will automatically have zero gradient.
  • You do not need to distinguish differentiable from nondifferentiable manually. Instead you should return gradients for all inexact JAX arrays in the first argument. (And just put None on every other leaf of the PyTree.)
  • As a convenience, all of the inputs from the forward pass are additionally made available to you on the backward pass.

Tip

If you need gradients with respect to multiple arguments, then just pack them together as a tuple via the first argument vjp_arg. (See also equinox.filter_grad for a similar trick.)


equinox.filter_vmap(fun = sentinel, *, default = <function _zero_if_array_else_none>, fn = None, args = (), kwargs = None, out = <function _zero_if_array_else_none>, **vmapkwargs)

Wraps together equinox.partition and jax.vmap.

Info

By default, all JAX arrays are vectorised down their leading axis (i.e. axis index 0), and all other types are not vectorised.

Arguments:

In each of the following cases, an int indicates an array axis to vectorise over, None indicates that an argument should be broadcast (not vectorised over), and functions Leaf -> Union[None, int] are mapped and evaluated on every leaf of their subtree. None should be used for non-JAX-array arguments.

  • fun is a pure function to vectorise.
  • default should be a Union[None, int] or a function Leaf -> Union[None, int], and is applied by default to every argument and keyword argument to fun.
  • args is an optional per-argument override for default, and should be a tuple of PyTrees with leaves that are either Union[None, int]s or functions Leaf -> Union[None, int]. The PyTree structures should be prefixes of the corresponding input to fun.
  • kwargs is an optional per-keyword-argument override for default and should be a dictionary, whose keys are the names of arguments to fun, and whose values are PyTrees with leaves that are either Union[None, int]s or functions Leaf -> Union[None, int]. The PyTree structures should be prefixes of the corresponding input to fun.
  • out should be a PyTree with leaves that are either Union[None, int]s or functions Leaf -> Union[None, int]. The PyTree structure should be a prefix of the output of fun.
  • fn should be a PyTree with leaves that are either Union[None, int]s or functions Leaf -> Union[None, int]. The PyTree structure should be a prefix of fun itself. (Note that fun may be any callable, e.g. a bound method, or a class implementing __call__, and doesn't have to be a normal Python function.)
  • **vmapkwargs are any other keyword arguments to jax.vmap.

When args, kwargs, out, fn are prefixes of the corresponding input, their value will be mapped over the input PyTree.

Returns:

The vectorised version of fun.

Info

In fact, besides None, int and Leaf -> Union[None, int], boolean types are also supported, and are treated identically to None. This is to support seamlessly switching between equinox.filter_pmap and equinox.filter_vmap if desired.

Example

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_vmap
def f(x, y):
    return x + y

@eqx.filter_vmap(kwargs=dict(x=1))
def g(x, y):
    return x + y

@eqx.filter_vmap(args=(None,))
def h(x, y):
    return x + y

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args vectorised down axis 0
f(jnp.array([1, 2]), 3)  # first arg vectorised down axis 0
                         # second arg broadcasted

g(jnp.array([[1, 2]]), jnp.array([3, 4]))  # first arg vectorised down axis 1
                                           # second arg vectorised down axis 0

h(jnp.array(1), jnp.array([2, 3]))  # first arg broadcasted
                                    # second arg vectorised down axis 0

Example

filter_vmap can be used to easily create ensembles of models. For example, here's an ensemble of eight MLPs:

import equinox as eqx
import jax.random as jr

key = jr.PRNGKey(0)
keys = jr.split(key, 8)

# Create an ensemble of models

@eqx.filter_vmap
def make_ensemble(key):
    return eqx.nn.MLP(2, 2, 2, 2, key=key)

mlp_ensemble = make_ensemble(keys)

# Evaluate each member of the ensemble on the same data

@eqx.filter_vmap(kwargs=dict(x=None))
def evaluate_ensemble(model, x):
    return model(x)

evaluate_ensemble(mlp_ensemble, jr.normal(key, (2,)))

# Evaluate each member of the ensemble on different data

@eqx.filter_vmap
def evaluate_per_ensemble(model, x):
    return model(x)

evaluate_per_ensemble(mlp_ensemble, jr.normal(key, (8, 2)))

Here, make_ensemble works because equinox.nn.MLP is a PyTree, and so it is a valid output from a filter_vmap. This PyTree includes some JAX arrays (the weights and biases) and some non-JAX-arrays (e.g. activation functions). filter_vmap will vectorise the JAX arrays (with separate weights for each member of the ensemble) whilst leaving the non-JAX-arrays alone.

Note that, as the weights in mlp_ensemble now have a leading batch dimension (which the weights of eqx.nn.MLP instances do not typically have), it cannot be called directly. It must instead be passed back into a vectorised region to be called.


equinox.filter_pmap(fun = sentinel, *, default = <function _zero_if_array_else_none>, fn = None, args = (), kwargs = None, out = <function _zero_if_array_else_none>, **pmapkwargs)

Wraps together equinox.partition and jax.pmap.

Info

By default, the computation is parallelised by splitting all JAX arrays down their leading axis (i.e. axis index 0), and broadcasting all other types to each replica.

Arguments:

In each of the following cases, an int indicates an array axis to split down, None indicates that an argument should be broadcast to each device (not split across devices), and functions Leaf -> Union[None, bool, int] are mapped and evaluated on every leaf of their subtree.

Note that jax.pmap, and thus equinox.filter_pmap, also JIT-compile their function in the same way as jax.jit. By default, all JAX arrays are traced and all other types are treated as static inputs. This may be controlled explicitly -- instead of just passing None -- by passing either True (traced) or False (static).

None, False and True should be used for non-JAX-array arguments.

  • fun is a pure function to parallelise.
  • default should be a Union[None, bool, int] or a function Leaf -> Union[None, bool, int], and is applied by default to every argument and keyword argument to fun.
  • args is an optional per-argument override for default, and should be a tuple of PyTrees with leaves that are either Union[None, bool, int]s or functions Leaf -> Union[None, bool, int]. The PyTree structures should be prefixes of the corresponding input to fun.
  • kwargs is an optional per-keyword-argument override for default and should be a dictionary, whose keys are the names of arguments to fun, and whose values are PyTrees with leaves that are either Union[None, bool, int]s or functions Leaf -> Union[None, bool, int]. The PyTree structures should be prefixes of the corresponding input to fun.
  • out should be a PyTree with leaves that are either Union[None, bool, int]s or functions Leaf -> Union[None, bool, int]. The PyTree structure should be a prefix of the output of fun. True indicates a tracer, False indicates any auxiliary information to return.
  • fn should be a PyTree with leaves that are either Union[None, bool, int]s or functions Leaf -> Union[None, bool, int]. The PyTree structure should be a prefix of fun itself. (Note that fun may be any callable, e.g. a bound method, or a class implementing __call__, and doesn't have to be a normal Python function.)
  • **pmapkwargs are any other keyword arguments to jax.pmap.

When args, kwargs, out, fn are prefixes of the corresponding input, their value will be mapped over the input PyTree.

Returns:

The parallelised version of fun.

Example

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_pmap
def f(x, y):
    return x + y

@eqx.filter_pmap(kwargs=dict(x=1))
def g(x, y):
    return x + y

@eqx.filter_pmap(args=(None,))
def h(x, y):
    return x + y

@eqx.filter_pmap
def apply(fun, x):
    return fun(x)

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args split down axis 0
f(jnp.array([1, 2]), 3)  # first arg split down axis 0
                         # second arg broadcasted

g(jnp.array([[1, 2]]), jnp.array([3, 4]))  # first arg split down axis 1
                                           # second arg split down axis 0

h(jnp.array(1), jnp.array([2, 3]))  # first arg broadcasted
                                    # second arg split down axis 0

apply(lambda x: x + 1, jnp.array([2, 3]))  # first arg broadcasted (as it's not
                                           # a JAX array)
                                           # second arg split down axis 0

equinox.filter_eval_shape(fun: Callable, *args, **kwargs)

As jax.eval_shape, but allows any Python object as inputs and outputs.

(jax.eval_shape is constrained to only work with JAX arrays, Python float/int/etc.)