

equinox.tree_pformat(pytree: PyTree, width: int = 80, indent: int = 2, short_arrays: bool = True, follow_wrapped: bool = True, truncate_leaf: Callable[[PyTree], bool] = <function _false>) -> str

Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays.

(This is the function used in the __repr__ of equinox.Module.)

All JAX arrays in the PyTree are condensed down to a short string representation of their dtype and shape.


For example, a 32-bit floating-point JAX array of shape (3, 4) is printed as f32[3,4].


  • pytree: The PyTree to pretty-format.
  • width: The desired maximum number of characters per line of output. If a structure cannot be formatted within this constraint then a best effort will be made.
  • indent: The amount of indentation for each nesting level.
  • short_arrays: Toggles the abbreviation of JAX arrays.
  • follow_wrapped: Whether to unwrap functools.partial and functools.wraps.
  • truncate_leaf: A function Any -> bool. Applied to all nodes in the PyTree; all truthy nodes will be truncated to just f"{type(node).__name__}(...)".


Returns: A string.

equinox.tree_equal(*pytrees: PyTree) -> Union[bool, numpy.bool_, Array]

Returns True if all input PyTrees are equal. To be considered equal, every PyTree must have the same structure, and any JAX or NumPy arrays (as leaves) must have the same shape, dtype, and values. JAX arrays and NumPy arrays are considered equal to each other.


  • *pytrees: Any number of PyTrees each with any structure.


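Returns: A boolean (a Python bool, a NumPy bool, or a scalar JAX boolean array, as per the return type above).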

equinox.filter_pure_callback(callback, *args, result_shape_dtypes, vectorized=False, **kwargs)

Calls a Python function inside a JIT region. This is like jax.pure_callback, but accepts arbitrary Python objects as inputs and outputs, not just JAX-compatible types.


  • callback: The Python function to call.
  • args, kwargs: The function will be called as callback(*args, **kwargs). These may be arbitrary Python objects.
  • result_shape_dtypes: A PyTree specifying the output of callback. It should have a jax.ShapeDtypeStruct in place of any JAX arrays.
  • vectorized: If True, then when transformed by vmap, callback is batched by calling it directly on the batched arrays. If False, callback is called on each batch element individually.


Returns: The result of callback(*args, **kwargs), valid for use under JIT.

equinox.static_field(**kwargs)

Used for marking that a field should not be treated as a leaf of the PyTree of an equinox.Module. (Instead, it is treated as part of the structure, i.e. as extra metadata.)


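import equinox
import jax.tree_util as jtu

class MyModule(equinox.Module):
    normal_field: str
    static_field: str = equinox.static_field()

mymodule = MyModule("normal", "static")
leaves, treedef = jtu.tree_flatten(mymodule)
assert leaves == ["normal"]  # only the non-static field is a leaf
assert "static" in str(treedef)  # the static value is stored in the treedef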

In practice this should rarely be used; it is usually preferable to filter out each field with eqx.filter whenever you need to select only some fields.


  • **kwargs: If any are passed, they are forwarded to dataclasses.field. (Recall that Equinox uses dataclasses for its modules.)