

equinox.tree_pformat(pytree: PyTree, width: int = 80, indent: int = 2, short_arrays: bool = True, follow_wrapped: bool = True, truncate_leaf: Callable[[PyTree], bool] = <function _false>) -> str

Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays.

(This is the function used by the __repr__ method of equinox.Module.)

All JAX arrays in the PyTree are condensed down to a short string representation of their dtype and shape.


For example, a 32-bit floating-point JAX array of shape (3, 4) is printed as f32[3,4].


  • pytree: The PyTree to pretty-format.
  • width: The desired maximum number of characters per line of output. If a structure cannot be formatted within this constraint then a best effort will be made.
  • indent: The amount of indentation at each nesting level.
  • short_arrays: Toggles the abbreviation of JAX arrays.
  • follow_wrapped: Whether to unwrap functools.partial objects and functions wrapped with functools.wraps.
  • truncate_leaf: A function Any -> bool. It is applied to every node in the PyTree; each node for which it returns a truthy value is truncated to just f"{type(node).__name__}(...)".


Returns the formatted string.

equinox.tree_equal(*pytrees: PyTree) -> Union[bool, numpy.bool_, Array]

Returns True if all input PyTrees are equal. All PyTrees must have the same structure to be considered equal. Any JAX or NumPy arrays (as leaves) must have the same shape, dtype, and values to be considered equal. JAX arrays and NumPy arrays are not considered equal to each other.


  • *pytrees: Any number of PyTrees, each with any structure.


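Returns a boolean indicating whether the PyTrees are equal. As the return annotation indicates, this may be a Python bool, a numpy.bool_, or a JAX array.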

equinox.static_field(**kwargs)

Used to mark a field of an equinox.Module that should not be treated as a leaf of its PyTree. (Instead, the field is treated as part of the structure, i.e. as extra metadata.)


import equinox
import jax.tree_util as jtu

class MyModule(equinox.Module):
    normal_field: str
    # Static field: stored in the treedef as metadata, not as a leaf.
    static_field: str = equinox.static_field()

mymodule = MyModule("normal", "static")
leaves, treedef = jtu.tree_flatten(mymodule)
assert leaves == ["normal"]
assert "static" in str(treedef)

In practice this should rarely be used; it is usually preferable to filter fields with eqx.filter whenever you need to select only some of them.


  • **kwargs: Any keyword arguments are passed on to dataclasses.field. (Recall that Equinox uses dataclasses for its modules.)