Transformations

These offer an alternate (easier to use) API for JAX transformations.

For example, JAX uses jax.jit(..., static_argnums=...) to manually indicate which arguments should be treated dynamically/statically. Meanwhile equinox.filter_jit automatically treats all JAX/NumPy arrays dynamically, and everything else statically. Moreover, this is done at the level of individual PyTree leaves, so that unlike jax.jit, one argument can have both dynamic (array-valued) and static leaves.

Most users find that this is a simpler API when working with complicated PyTrees, such as are produced when using Equinox modules. But you can also still use Equinox with normal jax.jit etc. if you so prefer.

Just-in-time compilation

equinox.filter_jit(fun = sentinel, *, donate: Literal['all', 'all-except-first', 'warn', 'warn-except-first', 'none'] = 'none')

An easier-to-use version of jax.jit. All JAX and NumPy arrays are traced, and all other types are held static.

Arguments:

  • fun is a pure function to JIT compile.
  • donate indicates whether the buffers of JAX arrays are donated or not. It should be one of:
    • 'all': donate all arrays and suppress all warnings about unused buffers;
    • 'all-except-first': donate all arrays except for those in the first argument, and suppress all warnings about unused buffers;
    • 'warn': as 'all', but don't suppress unused buffer warnings;
    • 'warn-except-first': as 'all-except-first', but don't suppress unused buffer warnings;
    • 'none': no buffer donation. (This is the default.)

Returns:

The JIT'd version of fun.

Example

import equinox as eqx
import jax.numpy as jnp

# Basic behaviour
@eqx.filter_jit
def f(x, y):  # both args traced if arrays, static if non-arrays
    return x + y, x - y

f(jnp.array(1), jnp.array(2))  # both args traced
f(jnp.array(1), 2)  # first arg traced, second arg static
f(1, 2)  # both args static

Info

Donating arguments allows their underlying memory to be used in the computation. This can produce speed and memory improvements, but means that you cannot use any donated arguments again, as their underlying memory has been overwritten. (JAX will throw an error if you try to.)

Info

If you want to trace Python bool/int/float/complex as well then you can do this by wrapping them into a JAX array: jnp.asarray(x).

To donate only some arguments, set filter_jit(donate="all-except-first") and pass every argument that you do not want donated through the first argument (packing multiple values into a tuple if necessary).


equinox.filter_make_jaxpr(fun: Callable[~_P, Any]) -> Callable[~_P, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]]

As jax.make_jaxpr, but accepts arbitrary PyTrees as input and output.

Arguments:

  • fun: The function fun(*args, **kwargs) whose jaxpr is to be computed. Its positional and keyword arguments may be anything, as can its return value.

Returns:

A wrapped version of fun, that when applied to example arguments *args, **kwargs, will return a 3-tuple of:

  • A ClosedJaxpr representing the evaluation of that function on those arguments.
  • A PyTree[jax.ShapeDtypeStruct] representing the output shape and dtype of the result.
  • A PyTree[Any] representing any non-array outputs from fun.

The example arguments to be traced may be anything with .shape and .dtype fields (typically JAX arrays, NumPy arrays, or jax.ShapeDtypeStructs). All other arguments are treated statically. In particular, Python builtins (bool, int, float, complex) are treated as static inputs; wrap them in JAX/NumPy arrays if you would like them to be traced.


equinox.filter_eval_shape(fun: Callable[..., Any], *args, **kwargs) -> PyTree[typing.Union[jax.ShapeDtypeStruct, typing.Any]]

As jax.eval_shape, but allows any Python object as inputs and outputs.

(jax.eval_shape is constrained to only work with JAX arrays, Python float/int/etc.)


equinox.filter_shard(x: PyTree[typing.Any], device_or_shardings: Union[jaxlib.xla_extension.Device, jaxlib.xla_extension.Sharding])

Filtered transform combining jax.lax.with_sharding_constraint and jax.device_put.

Enforces a sharding within a JIT'd computation (that is, how an array is split between multiple devices, e.g. multiple GPUs/TPUs), or moves x to a device.

Arguments:

  • x: A PyTree, with potentially a mix of arrays and non-arrays on the leaves. Its array leaves will have their shardings constrained.
  • device_or_shardings: Either a single device (e.g. a CPU or GPU) or a PyTree of sharding specifications. The structure should be a prefix of x.

Returns:

A copy of x with the specified sharding constraints.

Example

See also the autoparallelism example.

Automatic differentiation

equinox.filter_grad(fun = sentinel, *, has_aux: bool = False)

Creates a function that computes the gradient of fun.

The gradient will be computed with respect to all floating-point JAX/NumPy arrays in the first argument. (Which should be a PyTree.)

Any nondifferentiable leaves in the first argument will have None as the gradient.

Arguments:

  • fun is a pure function to differentiate.
  • has_aux: if True then fun should return a pair; the first element is the output to be differentiated and the second element is auxiliary data.

Returns:

A function with the same arguments as fun, that computes the derivative of fun with respect to its first input. Any nondifferentiable leaves will have None as the gradient.

If has_aux is True then a pair (gradient, aux) is returned. If has_aux is False then just the gradient is returned.

Tip

If you need to differentiate multiple objects, then put them together into a tuple and pass that through the first argument:

# We want to differentiate `func` with respect to both `x` and `y`.
def func(x, y):
    ...

@equinox.filter_grad
def grad_func(x__y):
    x, y = x__y
    return func(x, y)

Info

See also equinox.apply_updates for a convenience function that applies non-None gradient updates to a model.


equinox.filter_value_and_grad(fun = sentinel, *, has_aux: bool = False) -> Callable

Creates a function that evaluates both fun and the gradient of fun.

The gradient will be computed with respect to all floating-point JAX/NumPy arrays in the first argument. (Which should be a PyTree.)

Any nondifferentiable leaves in the first argument will have None as the gradient.

Arguments:

  • fun is a pure function to differentiate.
  • has_aux: if True then fun should return a pair; the first element is the output to be differentiated and the second element is auxiliary data.

Returns:

A function with the same arguments as fun, that evaluates both fun and computes the derivative of fun with respect to its first input. Any nondifferentiable leaves will have None as the gradient.

If has_aux is True then a nested tuple ((value, aux), gradient) is returned. If has_aux is False then the pair (value, gradient) is returned.


equinox.filter_jvp(fn: Callable[..., ~_T], primals: Sequence, tangents: Sequence, **kwargs) -> tuple[~_T, PyTree]

Like jax.jvp, but accepts arbitrary PyTrees. (Not just JAXable types.)

In the following, an "inexact arraylike" refers to either a floating-point JAX array, or a complex JAX array, or a Python float, or a Python complex. These are the types which JAX considers to be differentiable.

Arguments:

  • fn: Function to be differentiated. Its arguments can be Python objects, and its return type can be any Python object.
  • primals: The primal values at which fn should be evaluated. Should be a sequence of arguments, and its length should be equal to the number of positional parameters of fn.
  • tangents: The tangent vector for which the Jacobian-vector product should be calculated. Should be a PyTree with the same structure as primals. The leaves of tangents must either be inexact arraylikes, or they can be Nones. Nones are used to indicate (symbolic) zero tangents; in particular these must be passed for all primals that are not inexact arraylikes. (And None can also be passed for any inexact arraylike primals too.)
  • **kwargs: Any keyword arguments to pass to fn. These are not differentiated.

Returns:

A pair (primals_out, tangents_out) is returned, where primals_out = fn(*primals) and tangents_out is the Jacobian-vector product of fn evaluated at primals with tangents.

The tangents_out has the same structure as primals_out, but has None for any leaves with symbolic zero derivative. (Either because they're not differentiable -- i.e. they're not a floating-point JAX array or Python float -- or because they have no dependence on any input with non-symbolic-zero tangent.)

Tip

Unlike jax.jvp, this function does not support a has_aux argument. It isn't needed, as unlike jax.jvp the output of this function can be of arbitrary type.


equinox.filter_vjp(fun, *primals, has_aux: bool = False)

Like jax.vjp, but accepts arbitrary PyTrees. (Not just JAXable types.)

Arguments:

  • fun: The function to be differentiated. Will be called as fun(*primals). Can return an arbitrary PyTree.
  • primals: The arguments at which fun will be evaluated and differentiated. Can be arbitrary PyTrees.
  • has_aux: Indicates whether fun returns a pair, with the first element the output to be differentiated, and the second element auxiliary data. Defaults to False.

Returns:

If has_aux is False then returns a (primals_out, vjpfun) pair, where primals_out = fun(*primals) and vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals.

If has_aux is True then returns a tuple (primals_out, vjpfun, aux), where aux is the auxiliary data returned from fun.

The cotangent passed to vjpfun should have arrays corresponding to all floating-point arrays in primals_out, and None for all other PyTree leaves. The cotangents returned from vjpfun will likewise have arrays for all primals that are floating-point arrays, and None for all other PyTree leaves.


equinox.filter_jacfwd(fun, has_aux: bool = False)

Computes the Jacobian of fun, evaluated using forward-mode AD. The inputs and outputs may be arbitrary PyTrees.

Arguments:

  • fun: The function to be differentiated.
  • has_aux: Indicates whether fun returns a pair, with the first element the output to be differentiated, and the second element auxiliary data. Defaults to False.

Returns:

A function with the same arguments as fun.

If has_aux is False then this function returns just the Jacobian of fun with respect to its first argument.

If has_aux is True then it returns a pair (jacobian, aux), where aux is the auxiliary data returned from fun.

Warning

The outputs of fun must be JAX types: filtering is only applied to the inputs, not to the outputs.


equinox.filter_jacrev(fun, has_aux: bool = False)

Computes the Jacobian of fun, evaluated using reverse-mode AD. The inputs and outputs may be arbitrary PyTrees.

Arguments:

  • fun: The function to be differentiated.
  • has_aux: Indicates whether fun returns a pair, with the first element the output to be differentiated, and the second element auxiliary data. Defaults to False.

Returns:

A function with the same arguments as fun.

If has_aux is False then this function returns just the Jacobian of fun with respect to its first argument.

If has_aux is True then it returns a pair (jacobian, aux), where aux is the auxiliary data returned from fun.

Warning

The outputs of fun must be JAX types: filtering is only applied to the inputs, not to the outputs.


equinox.filter_hessian(fun, has_aux: bool = False)

Computes the Hessian of fun. The inputs and outputs may be arbitrary PyTrees.

Arguments:

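  • fun: The function to be differentiated.
  • has_aux: Indicates whether fun returns a pair, with the first element the output to be differentiated, and the second element auxiliary data. Defaults to False.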

Returns:

A function with the same arguments as fun.

If has_aux is False then this function returns just the Hessian of fun with respect to its first argument.

If has_aux is True then it returns a pair (hessian, aux), where aux is the auxiliary data returned from fun.

Warning

The outputs of fun must be JAX types: filtering is only applied to the inputs, not to the outputs.


equinox.filter_custom_jvp(fn)

Filtered version of jax.custom_jvp.

Works in the same way as jax.custom_jvp, except that you do not need to specify nondiff_argnums. Instead, arguments are automatically split into differentiable and nondifferentiable. (Everything that is not a floating-point array is necessarily nondifferentiable. In addition, some floating-point arrays may happen not to have been differentiated.)

The tangents of the nondifferentiable arguments will be passed as None.

The return types must still all be JAX types.

Supports keyword arguments, which are always treated as nondifferentiable.

Example

import equinox
import jax.numpy as jnp

@equinox.filter_custom_jvp
def call(x, y, *, fn):
    return fn(x, y)

@call.def_jvp
def call_jvp(primals, tangents, *, fn):
    x, y = primals
    tx, ty = tangents
    # `y` is not differentiated below, so it has a symbolic zero tangent,
    # represented as a `None`.
    assert ty is None
    primal_out = call(x, y, fn=fn)
    tangent_out = 2 * tx
    return primal_out, tangent_out

x = jnp.array(2.0)
y = jnp.array(2.0)
fn = lambda a, b: a + b
# This only computes gradients for the first argument `x`.
equinox.filter_grad(call)(x, y, fn=fn)

equinox.filter_custom_vjp(fn)

As jax.custom_vjp, but with a nicer interface.

Usage is:

@equinox.filter_custom_vjp
def fn(vjp_arg, *args, **kwargs):
    # `vjp_arg` is some PyTree of arbitrary Python objects.
    # `args`, `kwargs` contain arbitrary Python objects.
    ...
    return out  # some PyTree of arbitrary Python objects.

@fn.def_fwd
def fn_fwd(perturbed, vjp_arg, *args, **kwargs):
    # `perturbed` is a pytree with the same structure as `vjp_arg`. Every leaf is
    # either `True` or `False`, indicating whether that leaf is being
    # differentiated. (All leaves that are not floating-point arrays will
    # necessarily have `False`. Some floating-point arrays might happen not to be
    # differentiated either.)
    ...
    # Should return `out` as before. `residuals` can be any collection of JAX
    # arrays you want to keep around for the backward pass.
    return out, residuals

@fn.def_bwd
def fn_bwd(residuals, grad_obj, perturbed, vjp_arg, *args, **kwargs):
    # `grad_obj` will have `None` as the gradient for any leaves of `out` that were
    # not differentiated.
    ...
    # `grad_vjp_arg` should be a pytree with the same structure as `vjp_arg`.
    # It can have `None` leaves to indicate that that argument has zero gradient.
    # (E.g. if the leaf was not a JAX array.)
    return grad_vjp_arg

The key differences to jax.custom_vjp are that:

  • Only the gradient of the first argument, vjp_arg, should be computed on the backward pass. Everything else will automatically have zero gradient.
  • You do not need to distinguish differentiable from nondifferentiable manually. Instead you should return gradients for all perturbed arrays in the first argument. (And just put None on every other leaf of the PyTree.)
  • As a convenience, all of the inputs from the forward pass are additionally made available to you on the backward pass.
  • As a convenience, you can declare forward and backward passes using def_fwd and def_bwd, rather than a single defvjp as in core JAX.

Tip

If you need gradients with respect to multiple arguments, then just pack them together as a tuple via the first argument vjp_arg. (See also equinox.filter_grad for a similar trick.)


equinox.filter_checkpoint(fun: Callable[~_P, ~_T] = sentinel, *, prevent_cse: bool = True, policy: Optional[Callable[..., bool]] = None) -> Callable[~_P, ~_T]

Filtered version of jax.checkpoint.

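Gradient checkpointing is a technique for reducing memory usage during backpropagation (that is, reverse-mode automatic differentiation, e.g. jax.grad or equinox.filter_grad). Rather than storing all intermediate values from the forward pass, some of them are recomputed during the backward pass.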

Arguments:

  • fun: The function to be checkpointed. Will be called as fun(*args, **kwargs). Can return an arbitrary PyTree.
  • prevent_cse: If True (the default), then JAX will not perform common subexpression elimination. Please see the documentation for jax.checkpoint for more details.
  • policy: Callable for controlling which intermediate values should be rematerialized. It should be one of the attributes of jax.checkpoint_policies.

equinox.filter_closure_convert(fn: Callable[~_P, ~_T], *args, **kwargs) -> Callable[~_P, ~_T]

As jax.closure_convert, but works on functions accepting and returning arbitrary PyTree objects. In addition, all JAX arrays are hoisted into constants (not just floating point arrays).

This is useful for explicitly capturing any closed-over JAX tracers before crossing an API boundary, such as jax.grad, jax.custom_vjp, or the rule of a custom primitive.

Arguments:

  • fn: The function to call. Will be called as fn(*args, **kwargs).
  • args, kwargs: Example arguments at which to call the function. The function is not actually evaluated on these arguments; all JAX arrays are replaced by tracers. Note that Python builtins (bool, int, float, complex) are not replaced by tracers and are passed through as-is.

Returns:

A new function, which can be called in the same way, using *args and **kwargs. Will contain all closed-over tracers of fn as part of its PyTree structure.

Example

import jax
from equinox import filter_closure_convert

@jax.grad
def f(x, y):
    z = x + y
    g = lambda a: z + a  # closes over z
    g2 = filter_closure_convert(g, 1)
    assert [id(b) for b in g2.consts] == [id(z)]
    return z

f(1., 1.)

Vectorisation and parallelisation

equinox.filter_vmap(fun = sentinel, *, in_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), out_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), axis_name: Hashable = None, axis_size: Optional[int] = None)

Vectorises a function. By default, all JAX/NumPy arrays are vectorised down their leading axis (i.e. axis index 0), and all other types are broadcast.

Arguments:

For both in_axes and out_axes, an int indicates an array axis to vectorise over, None indicates that an argument should be broadcast (not vectorised over), and callables Leaf -> Union[None, int] are mapped and evaluated on every leaf of their subtree. None should be used for non-JAX-array arguments.

  • fun is a pure function to vectorise. Should be of the form fun(*args); that is to say it cannot accept keyword arguments.
  • in_axes indicates which axes of the input arrays should be vectorised over. It should be a PyTree of None, int, or callables Leaf -> Union[None, int]. Its tree structure should either be:
    1. a prefix of the input tuple of args.
    2. a dictionary, in which case the named arguments use the specified indices to vectorise over, and all other arguments will have the default eqx.if_array(0).
  • out_axes indicates which axis of the output arrays the mapped axis should appear at. It should be a PyTree of None, int, or callables Leaf -> Union[None, int], and its tree structure should be a prefix of the output fun(*args).
  • axis_name is an optional hashable Python object used to identify the mapped axis so that parallel collectives (e.g. jax.lax.psum) can be applied.
  • axis_size is an optional int describing the size of the axis mapped. This only needs to be passed if none of the input arguments are vectorised, as otherwise it can be deduced by looking at the argument shapes.

Returns:

The vectorised version of fun.

Tip

To vectorise all JAX/NumPy arrays down their jth axis, and broadcast all other types, you can use equinox.if_array(j), which returns a callable leaf -> j if is_array(leaf) else None. For example, the default values of in_axes and out_axes are both equinox.if_array(0).

Example

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_vmap
def f(x, y):
    return x + y

@eqx.filter_vmap(in_axes=(None, 1))
def g(x, y):
    return x + y

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args vectorised down axis 0

f(jnp.array([1, 2]), 3)                  # first arg vectorised down axis 0
                                         # second arg broadcasted

g(jnp.array(1), jnp.array([[2, 3]]))     # first arg broadcasted
                                         # second arg vectorised down axis 1

Example

filter_vmap can be used to easily create ensembles of models. For example, here's an ensemble of eight MLPs:

import equinox as eqx
import jax.random as jr

key = jr.PRNGKey(0)
keys = jr.split(key, 8)

# Create an ensemble of models

@eqx.filter_vmap
def make_ensemble(key):
    return eqx.nn.MLP(2, 2, 2, 2, key=key)

mlp_ensemble = make_ensemble(keys)

# Evaluate each member of the ensemble on the same data

@eqx.filter_vmap(in_axes=(eqx.if_array(0), None))
def evaluate_ensemble(model, x):
    return model(x)

evaluate_ensemble(mlp_ensemble, jr.normal(key, (2,)))

# Evaluate each member of the ensemble on different data

@eqx.filter_vmap
def evaluate_per_ensemble(model, x):
    return model(x)

evaluate_per_ensemble(mlp_ensemble, jr.normal(key, (8, 2)))

Here, make_ensemble works because equinox.nn.MLP is a PyTree, and so it is a valid output from a filter_vmap. This PyTree includes some JAX arrays (the weights and biases) and some non-JAX-arrays (e.g. activation functions). filter_vmap will vectorise the JAX arrays (with separate weights for each member of the ensemble) whilst leaving the non-JAX-arrays alone.

Note that the weights in mlp_ensemble now have a leading batch dimension, which the weights of eqx.nn.MLP instances do not normally have, so mlp_ensemble cannot be called directly. It must instead be passed back into a vectorised region to be called.


equinox.filter_pmap(fun = sentinel, *, in_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), out_axes: PyTree[typing.Union[int, NoneType, Callable[[typing.Any], typing.Optional[int]]]] = if_array(axis=0), axis_name: Hashable = None, axis_size: Optional[int] = None, donate: Literal['all', 'warn', 'none'] = 'none')

Warning

JAX has now added more powerful parallelism APIs directly to the JIT interface. As such, using equinox.filter_jit with sharded inputs is now recommended over filter_pmap. See also the parallelism example.

Parallelises a function. By default, all JAX/NumPy arrays are parallelised down their leading axis (i.e. axis index 0), and all other types are broadcast.

jax.pmap, and thus equinox.filter_pmap, also compile their function in the same way as jax.jit. By default, all JAX arrays are traced, and all other arguments are treated as static inputs.

Arguments:

For both in_axes and out_axes, an int indicates an array axis to parallelise over, None indicates that an argument should be broadcast (not parallelised over), and callables Leaf -> Union[None, int] are mapped and evaluated on every leaf of their subtree. None should be used for non-JAX-array arguments.

  • fun is a pure function to parallelise. Should be of the form fun(*args); that is to say it cannot accept keyword arguments.
  • in_axes indicates which axes of the input arrays should be parallelised over. It should be a PyTree of None, int, or callables Leaf -> Union[None, int]. Its tree structure should either be:
    1. a prefix of the input tuple of args.
    2. a dictionary, in which case the named arguments use the specified indices to parallelise over, and all other arguments will have the default eqx.if_array(0).
  • out_axes indicates which axis of the output arrays the mapped axis should appear at. It should be a PyTree of None, int, or callables Leaf -> Union[None, int], and its tree structure should be a prefix of the output fun(*args).
  • axis_name is an optional hashable Python object used to identify the mapped axis so that parallel collectives (e.g. jax.lax.psum) can be applied.
  • axis_size is an optional int describing the size of the axis mapped. This only needs to be passed if none of the input arguments are parallelised, as otherwise it can be deduced by looking at the argument shapes.
  • donate indicates whether the buffers of JAX arrays are donated or not. It should be one of:
    • 'all': donate all arrays and suppress all warnings about unused buffers;
    • 'warn': as above, but don't suppress unused buffer warnings;
    • 'none': the default, disables buffer donation.

Returns:

The parallelised version of fun.

Tip

To parallelise all JAX/NumPy arrays down their jth axis, and broadcast all other types, you can use equinox.if_array(j), which returns a callable leaf -> j if is_array(leaf) else None. For example, the default values of in_axes and out_axes are both equinox.if_array(0).

Example

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_pmap
def f(x, y):
    return x + y

@eqx.filter_pmap(in_axes=(None, 1))
def g(x, y):
    return x + y

f(jnp.array([1, 2]), jnp.array([3, 4]))  # both args parallelised down axis 0

f(jnp.array([1, 2]), 3)                  # first arg parallelised down axis 0
                                         # second arg broadcasted (as it's not
                                         # a JAX array)

g(jnp.array(1), jnp.array([[2, 3]]))     # first arg broadcasted
                                         # second arg parallelised down axis 1

Callbacks

equinox.filter_pure_callback(callback, *args, result_shape_dtypes, sharding = None, vmap_method = None, vectorized = None, **kwargs)

Calls a Python function inside a JIT region. As jax.pure_callback but accepts arbitrary Python objects as inputs and outputs. (Not just JAXable types.)

Note that unlike jax.pure_callback, the result_shape_dtypes argument must be passed as a keyword argument.

Arguments:

  • callback: The Python function to call.
  • *args, **kwargs: The function will be called as callback(*args, **kwargs). These may be arbitrary Python objects.
  • result_shape_dtypes: A PyTree specifying the output of callback. It should have a jax.ShapeDtypeStruct in place of any JAX arrays.
  • sharding: optional sharding that specifies the device from which the callback should be invoked.
  • vmap_method, vectorized: these specify how the callback transforms under vmap() as described in the documentation for jax.pure_callback.

Returns:

The result of callback(*args, **kwargs), valid for use under JIT.