Pretty printing
equinox.tree_pprint(pytree: Any, *, width: int = 80, indent: int = 2, short_arrays: bool = True, struct_as_array: bool = False, truncate_leaf: Callable[[Any], bool] = <function _false>) -> None
Pretty-prints a PyTree as a string, whilst abbreviating JAX arrays.
All JAX arrays in the PyTree are condensed down to a short string representation of their dtype and shape.
Example
A 32-bit floating-point JAX array of shape (3, 4) is printed as f32[3,4].
Arguments:
pytree: The PyTree to pretty-print.
width: The desired maximum number of characters per line of output. If a structure cannot be formatted within this constraint, then a best effort will be made.
indent: The amount of indentation for each nesting level.
short_arrays: Toggles the abbreviation of JAX arrays.
struct_as_array: Whether to treat jax.ShapeDtypeStructs as arrays.
truncate_leaf: A function Any -> bool. It is applied to every node in the PyTree, and each node for which it returns a truthy value is truncated to just f"{type(node).__name__}(...)".
Returns:
Nothing. (The result is printed to stdout instead.)
equinox.tree_pformat(pytree: Any, *, width: int = 80, indent: int = 2, short_arrays: bool = True, struct_as_array: bool = False, truncate_leaf: Callable[[Any], bool] = <function _false>) -> str
Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays.
(This is the function used in the __repr__ of equinox.Module.)
As equinox.tree_pprint, but returns the string instead of printing it.