# Transformations

These offer an alternate (easier to use) API for JAX transformations.

For example, JAX uses `jax.jit(..., static_argnums=...)` to manually indicate which arguments should be treated dynamically/statically. Meanwhile `equinox.filter_jit` automatically treats all JAX/NumPy arrays dynamically, and everything else statically. Moreover, this is done at the level of individual PyTree leaves, so that unlike `jax.jit`, one argument can have both dynamic (array-valued) and static leaves.

Most users find that this is a simpler API when working with complicated PyTrees, such as are produced when using Equinox modules. But you can also still use Equinox with normal `jax.jit` etc. if you so prefer.

## Just-in-time compilation

#### `equinox.filter_jit(fun = sentinel, *, donate: Literal['all', 'all-except-first', 'warn', 'warn-except-first', 'none'] = 'none')`

An easier-to-use version of `jax.jit`. All JAX and NumPy arrays are traced, and all other types are held static.

**Arguments:**

- `fun` is a pure function to JIT compile.
- `donate` indicates whether the buffers of JAX arrays are donated or not. It should either be:
    - `'all'`: donate all arrays and suppress all warnings about unused buffers;
    - `'all-except-first'`: donate all arrays except for those in the first argument, and suppress all warnings about unused buffers;
    - `'warn'`: as `'all'`, but don't suppress unused buffer warnings;
    - `'warn-except-first'`: as `'all-except-first'`, but don't suppress unused buffer warnings;
    - `'none'`: no buffer donation. (This is the default.)

**Returns:**

The JIT'd version of `fun`.

Example

```
import equinox as eqx
import jax.numpy as jnp

# Basic behaviour
@eqx.filter_jit
def f(x, y):  # both args traced if arrays, static if non-arrays
    return x + y, x - y

f(jnp.array(1), jnp.array(2))  # both args traced
f(jnp.array(1), 2)  # first arg traced, second arg static
f(1, 2)  # both args static
```

Info

Donating arguments allows their underlying memory to be used in the computation. This can produce speed and memory improvements, but means that you cannot use any donated arguments again, as their underlying memory has been overwritten. (JAX will throw an error if you try to.)

Info

If you want to trace Python `bool`/`int`/`float`/`complex` as well then you can do this by wrapping them into a JAX array: `jnp.asarray(x)`.

If you want to donate only some arguments then this can be done by setting `filter_jit(donate="all-except-first")` and then passing all arguments that you don't want to donate through the first argument. (Packing multiple values into a tuple if necessary.)
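
For instance, a minimal sketch of this donation pattern (the `step` function and its arguments here are hypothetical):

```
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_jit(donate="all-except-first")
def step(non_donated, params, x):
    (scale,) = non_donated  # kept alive: arrays in the first argument are not donated
    return params * x * scale

out = step((jnp.array(2.0),), jnp.ones(3), jnp.arange(3.0))
# The buffers of the second and third arguments were donated, so those
# arrays must not be used again after this call.
```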

#### `equinox.filter_make_jaxpr(fun: Callable[~_P, Any]) -> Callable[~_P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]]`

As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output.

**Arguments:**

- `fun`: The function `fun(*args, **kwargs)` whose jaxpr is to be computed. Its positional and keyword arguments may be anything, as can its return value.

**Returns:**

A wrapped version of `fun`, that when applied to example arguments `*args, **kwargs`, will return a 3-tuple of:

- A `ClosedJaxpr` representing the evaluation of that function on those arguments.
- A `PyTree[jax.ShapeDtypeStruct]` representing the output shape and dtype of the result.
- A `PyTree[Any]` representing any non-array outputs from `fun`.

The example arguments to be traced may be anything with `.shape` and `.dtype` fields (typically JAX arrays, NumPy arrays, or `jax.ShapeDtypeStruct`s). All other arguments are treated statically. In particular, Python builtins (`bool`, `int`, `float`, `complex`) are treated as static inputs; wrap them in JAX/NumPy arrays if you would like them to be traced.
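
As a brief illustration, a minimal sketch (`f` here is hypothetical):

```
import equinox as eqx
import jax.numpy as jnp

def f(x, scale):  # `x` is traced; `scale`, a Python float, is static
    return x * scale

jaxpr, out_shape, out_static = eqx.filter_make_jaxpr(f)(jnp.arange(3.0), 2.0)
print(jaxpr)      # the ClosedJaxpr recording the array computation
print(out_shape)  # jax.ShapeDtypeStruct(s) describing the array output
```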

#### `equinox.filter_eval_shape(fun: Callable[..., Any], *args, **kwargs) -> PyTree[typing.Union[jax.ShapeDtypeStruct, typing.Any]]`

As `jax.eval_shape`, but allows any Python object as inputs and outputs.

(`jax.eval_shape` is constrained to only work with JAX arrays, Python float/int/etc.)
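
For example, a common use is computing the structure of a model without allocating its parameters (a minimal sketch):

```
import equinox as eqx
import jax.random as jr

# All arrays in the result are `jax.ShapeDtypeStruct`s; non-array leaves
# (e.g. activation functions) are returned as-is.
abstract_mlp = eqx.filter_eval_shape(eqx.nn.MLP, 2, 2, 2, 2, key=jr.PRNGKey(0))
```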

#### `equinox.filter_shard(x: PyTree[typing.Any], device_or_shardings: Union[jaxlib.xla_extension.Device, jaxlib.xla_extension.Sharding])`

Filtered transform combining `jax.lax.with_sharding_constraint` and `jax.device_put`.

Enforces sharding within a JIT'd computation (that is, how an array is split between multiple devices, i.e. multiple GPUs/TPUs), or moves `x` to a device.

**Arguments:**

- `x`: A PyTree, with potentially a mix of arrays and non-arrays on the leaves. They will have their shardings constrained.
- `device_or_shardings`: Either a singular device (e.g. CPU or GPU) or a PyTree of sharding specifications. The structure should be a prefix of `x`.

**Returns:**

A copy of `x` with the specified sharding constraints.

Example

See also the autoparallelism example.
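
As a minimal sketch of the single-device case (multi-device use would pass a `jax.sharding.Sharding` instead):

```
import equinox as eqx
import jax
import jax.numpy as jnp

# Non-array leaves (here, the string) pass through untouched.
x = (jnp.arange(4.0), "static metadata")
x = eqx.filter_shard(x, jax.devices()[0])
```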

## Automatic differentiation

#### `equinox.filter_grad(fun = sentinel, *, has_aux: bool = False)`

Creates a function that computes the gradient of `fun`.

The gradient will be computed with respect to all floating-point JAX/NumPy arrays in the first argument. (Which should be a PyTree.)

Any nondifferentiable leaves in the first argument will have `None` as the gradient.

**Arguments:**

- `fun` is a pure function to differentiate.
- `has_aux`: if `True` then `fun` should return a pair; the first element is the output to be differentiated and the second element is auxiliary data.

**Returns:**

A function with the same arguments as `fun`, that computes the derivative of `fun` with respect to its first input. Any nondifferentiable leaves will have `None` as the gradient.

If `has_aux` is `True` then a pair `(gradient, aux)` is returned. If `has_aux` is `False` then just the `gradient` is returned.

Tip

If you need to differentiate multiple objects, then put them together into a tuple and pass that through the first argument:

```
# We want to differentiate `func` with respect to both `x` and `y`.
def func(x, y):
    ...

@equinox.filter_grad
def grad_func(x__y):
    x, y = x__y
    return func(x, y)
```

Info

See also `equinox.apply_updates` for a convenience function that applies non-`None` gradient updates to a model.
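
Putting these together, a minimal sketch of one step of gradient descent (the model and loss here are illustrative):

```
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

model = eqx.nn.Linear(2, 1, key=jr.PRNGKey(0))

@eqx.filter_grad
def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

grads = loss_fn(model, jnp.ones(2), jnp.zeros(1))  # same structure as `model`
updates = jax.tree_util.tree_map(lambda g: -0.1 * g, grads)
model = eqx.apply_updates(model, updates)  # `None` gradient leaves are skipped
```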

#### `equinox.filter_value_and_grad(fun = sentinel, *, has_aux: bool = False) -> Callable`

Creates a function that evaluates both `fun` and the gradient of `fun`.

The gradient will be computed with respect to all floating-point JAX/NumPy arrays in the first argument. (Which should be a PyTree.)

Any nondifferentiable leaves in the first argument will have `None` as the gradient.

**Arguments:**

- `fun` is a pure function to differentiate.
- `has_aux`: if `True` then `fun` should return a pair; the first element is the output to be differentiated and the second element is auxiliary data.

**Returns:**

A function with the same arguments as `fun`, that evaluates both `fun` and computes the derivative of `fun` with respect to its first input. Any nondifferentiable leaves will have `None` as the gradient.

If `has_aux` is `True` then a nested tuple `((value, aux), gradient)` is returned. If `has_aux` is `False` then the pair `(value, gradient)` is returned.
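
For example, a minimal sketch returning an auxiliary value alongside the loss (the model here is illustrative):

```
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

model = eqx.nn.Linear(2, 1, key=jr.PRNGKey(0))

@eqx.filter_value_and_grad(has_aux=True)
def loss_fn(model, x, y):
    pred = model(x)
    return jnp.mean((pred - y) ** 2), pred  # (loss, aux)

(loss, pred), grads = loss_fn(model, jnp.ones(2), jnp.zeros(1))
```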

#### `equinox.filter_jvp(fn: Callable[..., ~_T], primals: Sequence, tangents: Sequence, **kwargs) -> tuple[~_T, PyTree]`

Like `jax.jvp`, but accepts arbitrary PyTrees. (Not just JAXable types.)

In the following, an "inexact arraylike" refers to either a floating-point JAX array, or a complex JAX array, or a Python `float`, or a Python `complex`. These are the types which JAX considers to be differentiable.

**Arguments:**

- `fn`: Function to be differentiated. Its arguments can be Python objects, and its return type can be any Python object.
- `primals`: The primal values at which `fn` should be evaluated. Should be a sequence of arguments, and its length should be equal to the number of positional parameters of `fn`.
- `tangents`: The tangent vector for which the Jacobian-vector product should be calculated. Should be a PyTree with the same structure as `primals`. The leaves of `tangents` must either be inexact arraylikes, or they can be `None`s. `None`s are used to indicate (symbolic) zero tangents; in particular these must be passed for all primals that are not inexact arraylikes. (And `None` can also be passed for any inexact arraylike primals too.)
- `**kwargs`: Any keyword arguments to pass to `fn`. These are not differentiated.

**Returns:**

A pair `(primals_out, tangents_out)` is returned, where `primals_out = fn(*primals)` and `tangents_out` is the Jacobian-vector product of `fn` evaluated at `primals` with `tangents`.

The `tangents_out` has the same structure as `primals_out`, but has `None` for any leaves with symbolic zero derivative. (Either because they're not differentiable -- i.e. they're not a floating-point JAX array or Python `float` -- or because they have no dependence on any input with non-symbolic-zero tangent.)

Tip

Unlike `jax.jvp`, this function does not support a `has_aux` argument. It isn't needed, as unlike `jax.jvp` the output of this function can be of arbitrary type.
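
A minimal sketch, with a `None` tangent for a nondifferentiable integer argument:

```
import equinox as eqx
import jax.numpy as jnp

def f(x, n):  # `x` is differentiable; `n` (a Python int) is not
    return x ** n

primals = (jnp.array(2.0), 3)
tangents = (jnp.array(1.0), None)  # `None`: symbolic zero tangent for `n`
primals_out, tangents_out = eqx.filter_jvp(f, primals, tangents)
# primals_out == 8.0 and tangents_out == 12.0, since d(x**3)/dx = 3 * x**2.
```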

#### `equinox.filter_vjp(fun, *primals, has_aux: bool = False)`

Like `jax.vjp`, but accepts arbitrary PyTrees. (Not just JAXable types.)

**Arguments:**

- `fun`: The function to be differentiated. Will be called as `fun(*primals)`. Can return an arbitrary PyTree.
- `primals`: The arguments at which `fun` will be evaluated and differentiated. Can be arbitrary PyTrees.
- `has_aux`: Indicates whether `fun` returns a pair, with the first element the output to be differentiated, and the latter auxiliary data. Defaults to `False`.

**Returns:**

If `has_aux is False` then returns a `(primals_out, vjpfun)` pair, where `primals_out = fun(*primals)` and `vjpfun` is a function from a cotangent vector with the same shape as `primals_out` to a tuple of cotangent vectors with the same shape as `primals`, representing the vector-Jacobian product of `fun` evaluated at `primals`.

If `has_aux is True` then returns a tuple `(primals_out, vjpfun, aux)`, where `aux` is the auxiliary data returned from `fun`.

The cotangent passed to `vjpfun` should have arrays corresponding to all floating-point arrays in `primals_out`, and `None` for all other PyTree leaves. The cotangents returned from `vjpfun` will likewise have arrays for all `primals` that are floating-point arrays, and `None` for all other PyTree leaves.
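
A minimal sketch, with a nondifferentiable (string-valued) second primal:

```
import equinox as eqx
import jax.numpy as jnp

def f(x, label):  # `label` is static and nondifferentiable
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)
out, vjp_fn = eqx.filter_vjp(f, x, "squares")
grad_x, grad_label = vjp_fn(jnp.array(1.0))
# grad_x == 2 * x, whilst grad_label is None.
```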

#### `equinox.filter_jacfwd(fun, has_aux: bool = False)`

Computes the Jacobian of `fun`, evaluated using forward-mode AD. The inputs and outputs may be arbitrary PyTrees.

**Arguments:**

- `fun`: The function to be differentiated.
- `has_aux`: Indicates whether `fun` returns a pair, with the first element the output to be differentiated, and the latter auxiliary data. Defaults to `False`.

**Returns:**

A function with the same arguments as `fun`.

Warning

The outputs of `fun` must be JAX types; the filtering is only applied to the input, not the output.

If `has_aux is False` then this function returns just the Jacobian of `fun` with respect to its first argument.

If `has_aux is True` then it returns a pair `(jacobian, aux)`, where `aux` is the auxiliary data returned from `fun`.
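
A minimal sketch (`equinox.filter_jacrev` below is used in exactly the same way):

```
import equinox as eqx
import jax.numpy as jnp

def f(x):
    return jnp.stack([x[0] ** 2, x[0] * x[1]])

jac = eqx.filter_jacfwd(f)(jnp.array([2.0, 3.0]))
# jac[i, j] = d f_i / d x_j, i.e. [[4., 0.], [3., 2.]].
```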

#### `equinox.filter_jacrev(fun, has_aux: bool = False)`

Computes the Jacobian of `fun`, evaluated using reverse-mode AD. The inputs and outputs may be arbitrary PyTrees.

**Arguments:**

- `fun`: The function to be differentiated.
- `has_aux`: Indicates whether `fun` returns a pair, with the first element the output to be differentiated, and the latter auxiliary data. Defaults to `False`.

**Returns:**

A function with the same arguments as `fun`.

Warning

The outputs of `fun` must be JAX types; the filtering is only applied to the input, not the output.

If `has_aux is False` then this function returns just the Jacobian of `fun` with respect to its first argument.

If `has_aux is True` then it returns a pair `(jacobian, aux)`, where `aux` is the auxiliary data returned from `fun`.

#### `equinox.filter_hessian(fun, has_aux: bool = False)`

Computes the Hessian of `fun`. The inputs and outputs may be arbitrary PyTrees.

**Arguments:**

- `fun`: The function to be differentiated.
- `has_aux`: Indicates whether `fun` returns a pair, with the first element the output to be differentiated, and the latter auxiliary data. Defaults to `False`.

**Returns:**

A function with the same arguments as `fun`.

Warning

The outputs of `fun` must be JAX types; the filtering is only applied to the input, not the output.

If `has_aux is False` then this function returns just the Hessian of `fun` with respect to its first argument.

If `has_aux is True` then it returns a pair `(hessian, aux)`, where `aux` is the auxiliary data returned from `fun`.

#### `equinox.filter_custom_jvp(fn)`

Filtered version of `jax.custom_jvp`.

Works in the same way as `jax.custom_jvp`, except that you do not need to specify `nondiff_argnums`. Instead, arguments are automatically split into differentiable and nondifferentiable. (Everything that is not a floating-point array is necessarily nondifferentiable. In addition, some floating-point arrays may happen not to have been differentiated.)

The tangents of the nondifferentiable arguments will be passed as `None`.

The return types must still all be JAX types.

Supports keyword arguments, which are always treated as nondifferentiable.

Example

```
import equinox
import jax.numpy as jnp

@equinox.filter_custom_jvp
def call(x, y, *, fn):
    return fn(x, y)

@call.def_jvp
def call_jvp(primals, tangents, *, fn):
    x, y = primals
    tx, ty = tangents
    # `y` is not differentiated below, so it has a symbolic zero tangent,
    # represented as a `None`.
    assert ty is None
    primal_out = call(x, y, fn=fn)
    tangent_out = 2 * tx
    return primal_out, tangent_out

x = jnp.array(2.0)
y = jnp.array(2.0)
fn = lambda a, b: a + b
# This only computes gradients for the first argument `x`.
equinox.filter_grad(call)(x, y, fn=fn)
```

#### `equinox.filter_custom_vjp(fn)`

As `jax.custom_vjp`, but with a nicer interface.

Usage is:

```
@equinox.filter_custom_vjp
def fn(vjp_arg, *args, **kwargs):
    # `vjp_arg` is some PyTree of arbitrary Python objects.
    # `args`, `kwargs` contain arbitrary Python objects.
    ...
    return out  # some PyTree of arbitrary Python objects.

@fn.def_fwd
def fn_fwd(perturbed, vjp_arg, *args, **kwargs):
    # `perturbed` is a pytree with the same structure as `vjp_arg`. Every leaf is
    # either `True` or `False`, indicating whether that leaf is being
    # differentiated. (All leaves that are not floating-point arrays will
    # necessarily have `False`. Some floating-point arrays might happen not to be
    # differentiated either.)
    ...
    # Should return `out` as before. `residuals` can be any collection of JAX
    # arrays you want to keep around for the backward pass.
    return out, residuals

@fn.def_bwd
def fn_bwd(residuals, grad_obj, perturbed, vjp_arg, *args, **kwargs):
    # `grad_obj` will have `None` as the gradient for any leaves of `out` that were
    # not differentiated.
    ...
    # `grad_vjp_arg` should be a pytree with the same structure as `vjp_arg`.
    # It can have `None` leaves to indicate that that argument has zero gradient.
    # (E.g. if the leaf was not a JAX array.)
    return grad_vjp_arg
```

The key differences to `jax.custom_vjp` are that:

- Only the gradient of the first argument, `vjp_arg`, should be computed on the backward pass. Everything else will automatically have zero gradient.
- You do not need to distinguish differentiable from nondifferentiable manually. Instead you should return gradients for all perturbed arrays in the first argument. (And just put `None` on every other leaf of the PyTree.)
- As a convenience, all of the inputs from the forward pass are additionally made available to you on the backward pass.
- As a convenience, you can declare forward and backward passes using `def_fwd` and `def_bwd`, rather than a single `defvjp` as in core JAX.

Tip

If you need gradients with respect to multiple arguments, then just pack them together as a tuple via the first argument `vjp_arg`. (See also `equinox.filter_grad` for a similar trick.)
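
As a concrete (if artificial) instance of the skeleton above, here is a minimal sketch defining a custom VJP for squaring:

```
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_custom_vjp
def square(x):
    return x ** 2

@square.def_fwd
def square_fwd(perturbed, x):
    del perturbed
    return x ** 2, x  # save `x` as the residual for the backward pass

@square.def_bwd
def square_bwd(residual, grad_out, perturbed, x):
    del perturbed, x
    return 2 * residual * grad_out  # d(x**2)/dx = 2x

eqx.filter_grad(square)(jnp.array(3.0))  # Array(6., ...)
```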

#### `equinox.filter_checkpoint(fun: Callable[~_P, ~_T] = sentinel, *, prevent_cse: bool = True, policy: Optional[Callable[..., bool]] = None) -> Callable[~_P, ~_T]`

Filtered version of `jax.checkpoint`.

Gradient checkpointing is a technique for reducing memory usage during backpropagation, especially when used with reverse-mode automatic differentiation (e.g. `jax.grad` or `equinox.filter_grad`).

**Arguments:**

- `fun`: The function to be checkpointed. Will be called as `fun(*args, **kwargs)`. Can return an arbitrary PyTree.
- `prevent_cse`: If `True` (the default), then JAX will not perform common subexpression elimination. Please see the documentation for `jax.checkpoint` for more details.
- `policy`: Callable for controlling which intermediate values should be rematerialized. It should be one of the attributes of `jax.checkpoint_policies`.
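
A minimal sketch (the function being checkpointed here is illustrative):

```
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_checkpoint
def expensive(x):
    # Intermediate values here are rematerialized on the backward pass
    # rather than stored, trading compute for memory.
    for _ in range(4):
        x = jnp.sin(x) * jnp.cos(x)
    return jnp.sum(x)

eqx.filter_grad(expensive)(jnp.arange(3.0))
```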

#### `equinox.filter_closure_convert(fn: Callable[~_P, ~_T], *args, **kwargs) -> Callable[~_P, ~_T]`

As `jax.closure_convert`, but works on functions accepting and returning arbitrary PyTree objects. In addition, all JAX arrays are hoisted into constants (not just floating point arrays).

This is useful for explicitly capturing any closed-over JAX tracers before crossing an API boundary, such as `jax.grad`, `jax.custom_vjp`, or the rule of a custom primitive.

**Arguments:**

- `fn`: The function to call. Will be called as `fn(*args, **kwargs)`.
- `args`, `kwargs`: Example arguments at which to call the function. The function is not actually evaluated on these arguments; all JAX arrays are substituted for tracers. Note that Python builtins (`bool`, `int`, `float`, `complex`) are not substituted for tracers and are passed through as-is.

**Returns:**

A new function, which can be called in the same way, using `*args` and `**kwargs`. Will contain all closed-over tracers of `fn` as part of its PyTree structure.

Example

```
import jax
from equinox import filter_closure_convert

@jax.grad
def f(x, y):
    z = x + y
    g = lambda a: z + a  # closes over z
    g2 = filter_closure_convert(g, 1)
    assert [id(b) for b in g2.consts] == [id(z)]
    return z

f(1., 1.)
```

## Vectorisation and parallelisation

#### `equinox.filter_vmap(fun = sentinel, *, in_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), out_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), axis_name: Hashable = None, axis_size: Optional[int] = None)`

Vectorises a function. By default, all JAX/NumPy arrays are vectorised down their leading axis (i.e. axis index 0), and all other types are broadcast.

**Arguments:**

For both `in_axes` and `out_axes`, then `int` indicates an array axis to vectorise over, `None` indicates that an argument should be broadcast (not vectorised over), and callables `Leaf -> Union[None, int]` are mapped and evaluated on every leaf of their subtree. `None` should be used for non-JAX-array arguments.

- `fun` is a pure function to vectorise. Should be of the form `fun(*args)`; that is to say it cannot accept keyword arguments.
- `in_axes` indicates which axes of the input arrays should be vectorised over. It should be a PyTree of `None`, `int`, or callables `Leaf -> Union[None, int]`. Its tree structure should either be:
    - a prefix of the input tuple of `args`, or
    - a dictionary, in which case the named arguments will use the specified indices to vectorise over, and all other arguments will have the default `eqx.if_array(0)`.
- `out_axes` indicates which axis of the output arrays the mapped axis should appear at. It should be a PyTree of `None`, `int`, or callables `Leaf -> Union[None, int]`, and its tree structure should be a prefix of the output `fun(*args)`.
- `axis_name` is an optional hashable Python object used to identify the mapped axis so that parallel collectives (e.g. `jax.lax.psum`) can be applied.
- `axis_size` is an optional `int` describing the size of the axis mapped. This only needs to be passed if none of the input arguments are vectorised, as else it can be deduced by looking at the argument shapes.

**Returns:**

The vectorised version of `fun`.

Tip

To vectorise all JAX/NumPy arrays down their `j`th axis, and broadcast all other types, then you can use `equinox.if_array(j)`, which returns a callable `leaf -> j if is_array(leaf) else None`. For example: the default values of `in_axes` and `out_axes` are both `equinox.if_array(0)`.

Example

```
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_vmap
def f(x, y):
    return x + y

@eqx.filter_vmap(in_axes=(None, 1))
def g(x, y):
    return x + y

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args vectorised down axis 0
f(jnp.array([1, 2]), 3)  # first arg vectorised down axis 0,
                         # second arg broadcasted
g(jnp.array(1), jnp.array([[2, 3]]))  # first arg broadcasted,
                                      # second arg vectorised down axis 1
```

Example

`filter_vmap` can be used to easily create ensembles of models. For example, here's an ensemble of eight MLPs:

```
import equinox as eqx
import jax.random as jr

key = jr.PRNGKey(0)
keys = jr.split(key, 8)

# Create an ensemble of models
@eqx.filter_vmap
def make_ensemble(key):
    return eqx.nn.MLP(2, 2, 2, 2, key=key)

mlp_ensemble = make_ensemble(keys)

# Evaluate each member of the ensemble on the same data
@eqx.filter_vmap(in_axes=(eqx.if_array(0), None))
def evaluate_ensemble(model, x):
    return model(x)

evaluate_ensemble(mlp_ensemble, jr.normal(key, (2,)))

# Evaluate each member of the ensemble on different data
@eqx.filter_vmap
def evaluate_per_ensemble(model, x):
    return model(x)

evaluate_per_ensemble(mlp_ensemble, jr.normal(key, (8, 2)))
```

Here, `make_ensemble` works because `equinox.nn.MLP` is a PyTree, and so it is a valid output from a `filter_vmap`. This PyTree includes some JAX arrays (the weights and biases) and some non-JAX-arrays (e.g. activation functions). `filter_vmap` will vectorise the JAX arrays (with separate weights for each member of the ensemble) whilst leaving the non-JAX-arrays alone.

Note that as the weights in `mlp_ensemble` now have a leading batch dimension -- which the weights of `eqx.nn.MLP` instances do not typically have -- then it cannot be called directly. It must instead be passed back into a vectorised region to be called.

#### `equinox.filter_pmap(fun = sentinel, *, in_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), out_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), axis_name: Hashable = None, axis_size: Optional[int] = None, donate: Literal['all', 'warn', 'none'] = 'none')`

Warning

JAX has now added more powerful parallelism APIs directly to the JIT interface. As such, using `equinox.filter_jit` with sharded inputs is now recommended over `filter_pmap`. See also the parallelism example.

Parallelises a function. By default, all JAX/NumPy arrays are parallelised down their leading axis (i.e. axis index 0), and all other types are broadcast.

`jax.pmap`, and thus `equinox.filter_pmap`, also compile their function in the same way as `jax.jit`. By default, all JAX arrays are traced, and all other arguments are treated as static inputs.

**Arguments:**

For both `in_axes` and `out_axes`, then `int` indicates an array axis to parallelise over, `None` indicates that an argument should be broadcast (not parallelised over), and callables `Leaf -> Union[None, int]` are mapped and evaluated on every leaf of their subtree. `None` should be used for non-JAX-array arguments.

- `fun` is a pure function to parallelise. Should be of the form `fun(*args)`; that is to say it cannot accept keyword arguments.
- `in_axes` indicates which axes of the input arrays should be parallelised over. It should be a PyTree of `None`, `int`, or callables `Leaf -> Union[None, int]`. Its tree structure should either be:
    - a prefix of the input tuple of `args`, or
    - a dictionary, in which case the named arguments will use the specified indices to parallelise over, and all other arguments will have the default `eqx.if_array(0)`.
- `out_axes` indicates which axis of the output arrays the mapped axis should appear at. It should be a PyTree of `None`, `int`, or callables `Leaf -> Union[None, int]`, and its tree structure should be a prefix of the output `fun(*args)`.
- `axis_name` is an optional hashable Python object used to identify the mapped axis so that parallel collectives (e.g. `jax.lax.psum`) can be applied.
- `axis_size` is an optional `int` describing the size of the axis mapped. This only needs to be passed if none of the input arguments are parallelised, as else it can be deduced by looking at the argument shapes.
- `donate` indicates whether the buffers of JAX arrays are donated or not. It should either be:
    - `'all'`: donate all arrays and suppress all warnings about unused buffers;
    - `'warn'`: as above, but don't suppress unused buffer warnings;
    - `'none'`: the default; disables buffer donation.

**Returns:**

The parallelised version of `fun`.

Tip

To parallelise all JAX/NumPy arrays down their `j`th axis, and broadcast all other types, then you can use `equinox.if_array(j)`, which returns a callable `leaf -> j if is_array(leaf) else None`. For example: the default values of `in_axes` and `out_axes` are both `equinox.if_array(0)`.

Example

```
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_pmap
def f(x, y):
    return x + y

@eqx.filter_pmap(in_axes=(None, 1))
def g(x, y):
    return x + y

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args parallelised down axis 0
f(jnp.array([1, 2]), 3)  # first arg parallelised down axis 0,
                         # second arg broadcasted (as it's not a JAX array)
g(jnp.array(1), jnp.array([[2, 3]]))  # first arg broadcasted,
                                      # second arg parallelised down axis 1
```

## Callbacks

#### `equinox.filter_pure_callback(callback, *args, result_shape_dtypes, vectorized = False, **kwargs)`

Calls a Python function inside a JIT region. As `jax.pure_callback`, but accepts arbitrary Python objects as inputs and outputs. (Not just JAXable types.)

**Arguments:**

- `callback`: The Python function to call.
- `args`, `kwargs`: The function will be called as `callback(*args, **kwargs)`. These may be arbitrary Python objects.
- `result_shape_dtypes`: A PyTree specifying the output of `callback`. It should have a `jax.ShapeDtypeStruct` in place of any JAX arrays.
- `vectorized`: If `True` then `callback` is batched (when transformed by `vmap`) by calling it directly on the batched arrays. If `False` then `callback` is called on each batch element individually.

**Returns:**

The result of `callback(*args, **kwargs)`, valid for use under JIT.
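
A minimal sketch (the host-side function here is hypothetical):

```
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x, scale):  # arbitrary Python; runs outside the JIT region
    return np.asarray(x) * scale

@eqx.filter_jit
def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return eqx.filter_pure_callback(
        host_fn, x, 2.0, result_shape_dtypes=result_shape
    )

f(jnp.arange(3.0))
```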