Pretty printing
equinox.tree_pprint(pytree: PyTree, *, width: int = 80, indent: int = 2, short_arrays: bool = True, struct_as_array: bool = False, follow_wrapped: bool = True, truncate_leaf: Callable[[PyTree], bool] = <function _false>) -> None
Pretty-prints a PyTree as a string, whilst abbreviating JAX arrays.
All JAX arrays in the PyTree are condensed down to a short string representation of their dtype and shape.
Example
A 32-bit floating-point JAX array of shape (3, 4) is printed as f32[3,4].
Arguments:
pytree: The PyTree to pretty-print.
width: The desired maximum number of characters per line of output. If a structure cannot be formatted within this constraint, a best effort is made.
indent: The amount of indentation per nesting level.
short_arrays: Toggles the abbreviation of JAX arrays.
struct_as_array: Whether to treat jax.ShapeDtypeStructs as arrays.
follow_wrapped: Whether to unwrap functools.partial and functools.wraps.
truncate_leaf: A function Any -> bool. Applied to all nodes in the PyTree; all truthy nodes will be truncated to just f"{type(node).__name__}(...)".
Returns:
Nothing. (The result is printed to stdout instead.)
equinox.tree_pformat(pytree: PyTree, *, width: int = 80, indent: int = 2, short_arrays: bool = True, struct_as_array: bool = False, follow_wrapped: bool = True, truncate_leaf: Callable[[PyTree], bool] = <function _false>) -> str
Pretty-formats a PyTree as a string, whilst abbreviating JAX arrays.
(This is the function used in the __repr__ of equinox.Module.)
As equinox.tree_pprint, but returns the string instead of printing it.