equinox.tree_pprint(pytree: PyTree, *, width: int = 80, indent: int = 2, short_arrays: bool = True, struct_as_array: bool = False, follow_wrapped: bool = True, truncate_leaf: Callable[[PyTree], bool] = <function _false>) -> None
Pretty-prints a PyTree to stdout, abbreviating JAX arrays.
All JAX arrays in the PyTree are condensed down to a short string representation of their dtype and shape.
For example, a 32-bit floating-point JAX array of shape (3, 4) is printed as f32[3,4].
pytree: The PyTree to pretty-print.
width: The desired maximum number of characters per line of output. If a structure cannot be formatted within this constraint then a best effort will be made.
indent: The amount of indentation added at each nesting level.
short_arrays: Toggles the abbreviation of JAX arrays.
struct_as_array: Whether to treat jax.ShapeDtypeStructs as arrays.
follow_wrapped: Whether to unwrap functools.partial and functools.wraps.
truncate_leaf: A function Any -> bool. Applied to all nodes in the PyTree; all truthy nodes will be truncated to just f(...).
Returns: Nothing. (The result is printed to stdout instead.)