# Filtered transformations

These typically combine `equinox.partition`, `equinox.combine`, and a JAX transformation, all together.

Generally speaking, this means producing an enhanced version of the JAX transformation, one that operates on arbitrary PyTrees instead of just JAX arrays.

Practically speaking, these are usually the only kind of filtering you ever have to use. (But it's good to understand what e.g. `equinox.partition` and `equinox.is_array` are doing under the hood, just so that these don't seem too magical.)

## Just-in-time compilation

#### `equinox.filter_jit(fun = sentinel, *, default = <function is_array>, fn = <function is_array>, args = (), kwargs = None, out = <function is_array>, **jitkwargs)`

Wraps together `equinox.partition` and `jax.jit`.

Info

By default, all JAX arrays are traced, and all other types are held static.

**Arguments:**

In each of the following cases, `True` indicates that an argument should be traced, `False` indicates that an argument should be held static, and functions `Leaf -> bool` are mapped and evaluated on every leaf of their subtree.

- `fun` is a pure function to JIT compile.
- `default` should be a `bool` or a function `Leaf -> bool`, and is applied by default to every argument and keyword argument to `fun`.
- `args` is an optional per-argument override for `default`, and should be a tuple of PyTrees with leaves that are either `bool`s or functions `Leaf -> bool`. The PyTree structures should be prefixes of the corresponding input to `fun`.
- `kwargs` is an optional per-keyword-argument override for `default`, and should be a dictionary whose keys are the names of arguments to `fun`, and whose values are PyTrees with leaves that are either `bool`s or functions `Leaf -> bool`. The PyTree structures should be prefixes of the corresponding input to `fun`.
- `out` should be a PyTree with leaves that are either `bool`s or functions `Leaf -> bool`. The PyTree structure should be a prefix of the output of `fun`. Truthy values should be tracers; falsey values are any (non-tracer) auxiliary information to return.
- `fn` should be a PyTree with leaves that are either `bool`s or functions `Leaf -> bool`. The PyTree structure should be a prefix of `fun` itself. (Note that `fun` may be any callable, e.g. a bound method, or a class implementing `__call__`, and doesn't have to be a normal Python function.)
- `**jitkwargs` are any other keyword arguments to `jax.jit`.

When `args`, `kwargs`, `out`, `fn` are prefixes of the corresponding input, their value will be mapped over the input PyTree.

**Returns:**

The JIT'd version of `fun`.

Example

```
@eqx.filter_jit
def f(x, y):  # both args traced if arrays, static if non-arrays
    return x + y

@eqx.filter_jit(kwargs=dict(x=False))
def g(x, y):  # x held static; y is traced if array, static if non-array
    return x + y

@eqx.filter_jit(args=(True,))
def h(x):
    return x

@eqx.filter_jit
def apply(f, x):
    return f(x)

f(jnp.array(1), jnp.array(2))  # both args traced
f(jnp.array(1), 2)  # first arg traced, second arg static
f(1, 2)  # both args static

g(1, jnp.array(2))  # first arg static, second arg traced
g(1, 2)  # both args static

h(1)  # traced
h(jnp.array(1))  # traced
h("hi")  # not a trace-able JAX type, so error

apply(lambda x: x + 1, jnp.array(1))  # first arg static, second arg traced.
```

#### `equinox.filter_make_jaxpr(fun)`

As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output.

**Arguments:**

- `fun`: The function `fun(*args, **kwargs)` whose jaxpr is to be computed. Its positional and keyword arguments may be anything, as can its return value.

**Returns:**

A wrapped version of `fun`, that when applied to example arguments `*args, **kwargs`, will return a 3-tuple of:

- A `ClosedJaxpr` representing the evaluation of that function on those arguments.
- A `PyTree[jax.ShapeDtypeStruct]` representing the output shape and dtype of the result.
- A `PyTree[Any]` representing any non-array outputs from `fun`.

The example arguments to be traced may be anything with `.shape` and `.dtype` fields (typically JAX arrays, NumPy arrays, or `jax.ShapeDtypeStruct`s). All other arguments are treated statically. In particular, Python builtins (`bool`, `int`, `float`, `complex`) are treated as static inputs; wrap them in JAX/NumPy arrays if you would like them to be traced.

#### `equinox.filter_eval_shape(fun: Callable, *args, **kwargs)`

As `jax.eval_shape`, but allows any Python object as inputs and outputs.

(`jax.eval_shape` is constrained to only work with JAX arrays, Python float/int/etc.)

## Automatic differentiation

#### `equinox.filter_grad(fun = sentinel, *, arg = <function is_inexact_array>, **gradkwargs)`

As `jax.grad`, but accepts arbitrary PyTrees as inputs. (Not just JAXable types.)

Info

By default, all inexact (floating-point) JAX arrays are differentiated. Any nondifferentiable leaves will have `None` as the gradient.

**Arguments:**

- `fun` is a pure function to differentiate.
- `arg` is a PyTree whose structure should be a prefix of the structure of the **first** argument to `fun`. It behaves as the `filter_spec` argument to `equinox.filter`. Truthy values will be differentiated; falsey values will not.
- `**gradkwargs` are any other keyword arguments to `jax.grad`.

**Returns:**

A function computing the derivative of `fun` with respect to its first input. Any nondifferentiable leaves will have `None` as the gradient. See `equinox.apply_updates` for a convenience function that will only attempt to apply non-`None` updates.

Tip

If you need to differentiate multiple objects, then put them together into a tuple and pass that through the first argument:

```
# We want to differentiate `func` with respect to both `x` and `y`.
def func(x, y):
    ...

@equinox.filter_grad
def grad_func(x__y):
    x, y = x__y
    return func(x, y)
```

#### `equinox.filter_value_and_grad(fun = sentinel, *, arg = <function is_inexact_array>, **gradkwargs)`

As `equinox.filter_grad`, except that it is `jax.value_and_grad` that is wrapped.

#### `equinox.filter_jvp(fn, primals, tangents)`

Like `jax.jvp`, but accepts arbitrary PyTrees. (Not just JAXable types.)

**Arguments:**

- `fn`: Function to be differentiated. Its arguments can be Python objects, and its return type can be any Python object.
- `primals`: The primal values at which `fn` should be evaluated. Should be a sequence of arguments, and its length should be equal to the number of positional parameters of `fn`.
- `tangents`: The tangent vector for which the Jacobian-vector product should be calculated. Should be a PyTree with the same structure as `primals`. The leaves of `tangents` must be either floating-point JAX arrays, or Python floats, or `None`s. The tangent must be `None` for any primal which is not itself a floating-point JAX array or Python float.

**Returns:**

A pair `(primals_out, tangents_out)` is returned, where `primals_out = fn(*primals)` and `tangents_out` is the Jacobian-vector product of `fn` evaluated at `primals` with `tangents`.

The `tangents_out` has the same structure as `primals_out`, but has `None` for any leaves that aren't differentiable.

Tip

Unlike `jax.jvp`, this function does not support a `has_aux` argument. It isn't needed, as unlike `jax.jvp` the output of this function can be of arbitrary type.

#### `equinox.filter_vjp(fun, *primals, has_aux = False)`

Filtered version of `jax.vjp`.

**Arguments:**

- `fun`: The function to be differentiated. Will be called as `fun(*primals)`. Can return an arbitrary PyTree.
- `primals`: The arguments at which `fun` will be evaluated and differentiated. Can be arbitrary PyTrees.
- `has_aux`: Indicates whether `fun` returns a pair, with the first element the output to be differentiated, and the latter auxiliary data. Defaults to `False`.

**Returns:**

If `has_aux is False` then returns a `(primals_out, vjpfun)` pair, where `primals_out = fun(*primals)` and `vjpfun` is a function from a cotangent vector with the same shape as `primals_out` to a tuple of cotangent vectors with the same shape as `primals`, representing the vector-Jacobian product of `fun` evaluated at `primals`.

If `has_aux is True` then returns a tuple `(primals_out, vjpfun, aux)`, where `aux` is the auxiliary data returned from `fun`.

The cotangent passed to `vjpfun` should have arrays corresponding to all floating-point arrays in `primals_out`, and `None` for all other PyTree leaves. The cotangents returned from `vjpfun` will likewise have arrays for all `primals` that are floating-point arrays, and `None` for all other PyTree leaves.

#### `equinox.filter_custom_jvp(fn)`

Filtered version of `jax.custom_jvp`.

Works in the same way as `jax.custom_jvp`, except that you do not need to specify `nondiff_argnums`. Instead, arguments are automatically split into differentiable and nondifferentiable based on whether or not they are a floating-point JAX array.

The tangents of the nondifferentiable arguments will be passed as `None`.

The return types must still all be JAX types.

**Examples:**

```
@equinox.filter_custom_jvp
def call(fn, x):
    return fn(x)

@call.defjvp
def call_jvp(primals, tangents):
    fn, x = primals
    _, tx = tangents
    primal_out = call(fn, x)
    tangent_out = tx**2
    return primal_out, tangent_out
```

#### `equinox.filter_custom_vjp(fn)`

As `jax.custom_vjp`, but with a nicer interface.

Usage is:

```
@equinox.filter_custom_vjp
def fn(vjp_arg, *args, **kwargs):
    # vjp_arg is some PyTree of arbitrary Python objects.
    # args, kwargs contain arbitrary Python objects.
    ...
    return obj  # some PyTree of arbitrary Python objects.

def fn_fwd(vjp_arg, *args, **kwargs):
    ...
    # Should return `obj` as before. `residuals` can be any collection of JAX
    # arrays you want to keep around for the backward pass.
    return obj, residuals

def fn_bwd(residuals, grad_obj, vjp_arg, *args, **kwargs):
    # grad_obj will have `None` as the gradient for any leaves of `obj` that
    # were not JAX arrays.
    ...
    # grad_vjp_arg should have `None` as the gradient for any leaves of
    # `vjp_arg` that were not JAX arrays.
    return grad_vjp_arg

fn.defvjp(fn_fwd, fn_bwd)
```

The key differences to `jax.custom_vjp` are that:

- Only the gradient of the first argument, `vjp_arg`, should be computed on the backward pass. Everything else will automatically have zero gradient.
- You do not need to distinguish differentiable from nondifferentiable manually. Instead you should return gradients for all inexact JAX arrays in the first argument. (And just put `None` on every other leaf of the PyTree.)
- As a convenience, all of the inputs from the forward pass are additionally made available to you on the backward pass.

Tip

If you need gradients with respect to multiple arguments, then just pack them together as a tuple via the first argument `vjp_arg`. (See also `equinox.filter_grad` for a similar trick.)

#### `equinox.filter_closure_convert(fn, *args, **kwargs)`

As `jax.closure_convert`, but works on functions accepting and returning arbitrary PyTree objects. In addition, all JAX arrays are hoisted into constants (not just floating point arrays).

This is useful for explicitly capturing any closed-over JAX tracers before crossing an API boundary, such as `jax.grad`, `jax.custom_vjp`, or the rule of a custom primitive.

**Arguments:**

- `fn`: The function to call. Will be called as `fn(*args, **kwargs)`.
- `args`, `kwargs`: Example arguments at which to call the function. The function is not actually evaluated on these arguments; all JAX arrays are substituted for tracers. Note that Python builtins (`bool`, `int`, `float`, `complex`) are not substituted for tracers and are passed through as-is.

**Returns:**

A new function, which can be called in the same way, using `*args` and `**kwargs`. Will contain all closed-over tracers of `fn` as part of its PyTree structure.

Example

```
@jax.grad
def f(x, y):
    z = x + y
    g = lambda a: z + a  # closes over z
    g2 = filter_closure_convert(g, 1)
    assert [id(b) for b in g2.consts] == [id(z)]
    return z

f(1., 1.)
```

## Vectorisation and parallelisation

#### `equinox.filter_vmap(fun = sentinel, *, default = <function _zero_if_array_else_none>, fn = None, args = (), kwargs = None, out = <function _zero_if_array_else_none>, **vmapkwargs)`

Wraps together `equinox.partition` and `jax.vmap`.

Info

By default, all JAX arrays are vectorised down their leading axis (i.e. axis index 0), and all other types are not vectorised.

**Arguments:**

In each of the following cases, an `int` indicates an array axis to vectorise over, `None` indicates that an argument should be broadcast (not vectorised over), and functions `Leaf -> Union[None, int]` are mapped and evaluated on every leaf of their subtree. `None` should be used for non-JAX-array arguments.

- `fun` is a pure function to vectorise.
- `default` should be a `Union[None, int]` or a function `Leaf -> Union[None, int]`, and is applied by default to every argument and keyword argument to `fun`.
- `args` is an optional per-argument override for `default`, and should be a tuple of PyTrees with leaves that are either `Union[None, int]`s or functions `Leaf -> Union[None, int]`. The PyTree structures should be prefixes of the corresponding input to `fun`.
- `kwargs` is an optional per-keyword-argument override for `default` and should be a dictionary, whose keys are the names of arguments to `fun`, and whose values are PyTrees with leaves that are either `Union[None, int]`s or functions `Leaf -> Union[None, int]`. The PyTree structures should be prefixes of the corresponding input to `fun`.
- `out` should be a PyTree with leaves that are either `Union[None, int]`s or functions `Leaf -> Union[None, int]`. The PyTree structure should be a prefix of the output of `fun`.
- `fn` should be a PyTree with leaves that are either `Union[None, int]`s or functions `Leaf -> Union[None, int]`. The PyTree structure should be a prefix of `fun` itself. (Note that `fun` may be any callable, e.g. a bound method, or a class implementing `__call__`, and doesn't have to be a normal Python function.)
- `**vmapkwargs` are any other keyword arguments to `jax.vmap`.

When `args`, `kwargs`, `out`, `fn` are prefixes of the corresponding input, their value will be mapped over the input PyTree.

**Returns:**

The vectorised version of `fun`.

Info

In fact, besides `None`, `int` and `Leaf -> Union[None, int]`: boolean types are also supported and treated identically to `None`. This is to support seamlessly switching between `equinox.filter_pmap` and `equinox.filter_vmap` if desired.

Example

```
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_vmap
def f(x, y):
    return x + y

@eqx.filter_vmap(kwargs=dict(x=1))
def g(x, y):
    return x + y

@eqx.filter_vmap(args=(None,))
def h(x, y):
    return x + y

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args vectorised down axis 0
f(jnp.array([1, 2]), 3)  # first arg vectorised down axis 0,
                         # second arg broadcasted
g(jnp.array([[1, 2]]), jnp.array([3, 4]))  # first arg vectorised down axis 1,
                                           # second arg vectorised down axis 0
h(jnp.array(1), jnp.array([2, 3]))  # first arg broadcasted,
                                    # second arg vectorised down axis 0
```

Example

`filter_vmap` can be used to easily create ensembles of models. For example, here's an ensemble of eight MLPs:

```
import equinox as eqx
import jax.random as jr

key = jr.PRNGKey(0)
keys = jr.split(key, 8)

# Create an ensemble of models
@eqx.filter_vmap
def make_ensemble(key):
    return eqx.nn.MLP(2, 2, 2, 2, key=key)

mlp_ensemble = make_ensemble(keys)

# Evaluate each member of the ensemble on the same data
@eqx.filter_vmap(kwargs=dict(x=None))
def evaluate_ensemble(model, x):
    return model(x)

evaluate_ensemble(mlp_ensemble, jr.normal(key, (2,)))

# Evaluate each member of the ensemble on different data
@eqx.filter_vmap
def evaluate_per_ensemble(model, x):
    return model(x)

evaluate_per_ensemble(mlp_ensemble, jr.normal(key, (8, 2)))
```

Here, `make_ensemble` works because `equinox.nn.MLP` is a PyTree, and so it is a valid output from a `filter_vmap`. This PyTree includes some JAX arrays (the weights and biases) and some non-JAX-arrays (e.g. activation functions). `filter_vmap` will vectorise the JAX arrays (with separate weights for each member of the ensemble) whilst leaving the non-JAX-arrays alone.

Note that, as the weights in `mlp_ensemble` now have a leading batch dimension -- which the weights of `eqx.nn.MLP` instances do not typically have -- it cannot be called directly. It must instead be passed back into a vectorised region to be called.

#### `equinox.filter_pmap(fun = sentinel, *, default = <function _zero_if_array_else_none>, fn = None, args = (), kwargs = None, out = <function _zero_if_array_else_none>, **pmapkwargs)`

Wraps together `equinox.partition` and `jax.pmap`.

Info

By default, the computation is parallelised by splitting all JAX arrays down their leading axis (i.e. axis index 0), and broadcasting all other types to each replica.

**Arguments:**

In each of the following cases, an `int` indicates an array axis to split down, `None` indicates that an argument should be broadcast to each device (not split across devices), and functions `Leaf -> Union[None, bool, int]` are mapped and evaluated on every leaf of their subtree.

Note that `jax.pmap`, and thus `equinox.filter_pmap`, also JIT-compile their function in the same way as `jax.jit`. By default, all JAX arrays are traced and all other arguments are treated as static inputs. This may be controlled explicitly -- instead of just passing `None` -- by passing either `True` (traced) or `False` (static).

`None`, `False` and `True` should be used for non-JAX-array arguments.

- `fun` is a pure function to parallelise.
- `default` should be a `Union[None, bool, int]` or a function `Leaf -> Union[None, bool, int]`, and is applied by default to every argument and keyword argument to `fun`.
- `args` is an optional per-argument override for `default`, and should be a tuple of PyTrees with leaves that are either `Union[None, bool, int]`s or functions `Leaf -> Union[None, bool, int]`. The PyTree structures should be prefixes of the corresponding input to `fun`.
- `kwargs` is an optional per-keyword-argument override for `default` and should be a dictionary, whose keys are the names of arguments to `fun`, and whose values are PyTrees with leaves that are either `Union[None, bool, int]`s or functions `Leaf -> Union[None, bool, int]`. The PyTree structures should be prefixes of the corresponding input to `fun`.
- `out` should be a PyTree with leaves that are either `Union[None, bool, int]`s or functions `Leaf -> Union[None, bool, int]`. The PyTree structure should be a prefix of the output of `fun`. `True` indicates a tracer; `False` indicates any auxiliary information to return.
- `fn` should be a PyTree with leaves that are either `Union[None, bool, int]`s or functions `Leaf -> Union[None, bool, int]`. The PyTree structure should be a prefix of `fun` itself. (Note that `fun` may be any callable, e.g. a bound method, or a class implementing `__call__`, and doesn't have to be a normal Python function.)
- `**pmapkwargs` are any other keyword arguments to `jax.pmap`.

When `args`, `kwargs`, `out`, `fn` are prefixes of the corresponding input, their value will be mapped over the input PyTree.

**Returns:**

The parallelised version of `fun`.

Example

```
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_pmap
def f(x, y):
    return x + y

@eqx.filter_pmap(kwargs=dict(x=1))
def g(x, y):
    return x + y

@eqx.filter_pmap(args=(None,))
def h(x, y):
    return x + y

@eqx.filter_pmap
def apply(fun, x):
    return fun(x)

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args split down axis 0
f(jnp.array([1, 2]), 3)  # first arg split down axis 0,
                         # second arg broadcasted
g(jnp.array([[1, 2]]), jnp.array([3, 4]))  # first arg split down axis 1,
                                           # second arg split down axis 0
h(jnp.array(1), jnp.array([2, 3]))  # first arg broadcasted,
                                    # second arg split down axis 0
apply(lambda x: x + 1, jnp.array([2, 3]))  # first arg broadcasted (as it's not
                                           # a JAX array),
                                           # second arg split down axis 0
```