
Pretty printing

equinox.tree_pprint(pytree: PyTree, *, width: int = 80, indent: int = 2, short_arrays: bool = True, struct_as_array: bool = False, follow_wrapped: bool = True, truncate_leaf: Callable[[PyTree], bool] = <function _false>) -> None

Pretty-prints a PyTree as a string, whilst abbreviating JAX arrays.

All JAX arrays in the PyTree are condensed down to a short string representation of their dtype and shape.

Example

A 32-bit floating-point JAX array of shape (3, 4) is printed as f32[3,4].

Arguments:

  • pytree: The PyTree to pretty-print.
  • width: The desired maximum number of characters per line of output. If a structure cannot be formatted within this constraint, then a best effort will be made.
  • indent: The amount of indentation to use for each nesting level.
  • short_arrays: Whether to abbreviate JAX arrays as described above.
  • struct_as_array: Whether to treat jax.ShapeDtypeStructs as arrays.
  • follow_wrapped: Whether to unwrap functools.partial and functools.wraps.
  • truncate_leaf: A function Any -> bool. It is applied to every node in the PyTree, and each node for which it returns a truthy value is truncated to just f"{type(node).__name__}(...)".

Returns:

Nothing. (The result is printed to stdout instead.)


equinox.tree_pformat(pytree: PyTree, *, width: int = 80, indent: int = 2, short_arrays: bool = True, struct_as_array: bool = False, follow_wrapped: bool = True, truncate_leaf: Callable[[PyTree], bool] = <function _false>) -> str

Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays.

(This is the function used in __repr__ of equinox.Module.)

As equinox.tree_pprint, but returns the string instead of printing it.