
Trees

Nested data structures are ubiquitous: for example, the config for a machine learning model. Manipulating these, whether by mapping over them or by replacing pieces of them, is essential, and here we provide facilities for doing so.

We refer to such objects as 'trees': specifically, any nested collection of lists, tuples, namedtuples, dictionaries, functools.partials, methods, function closures, and subclasses of structional.Struct. That last one is particularly important, as it is typically how custom tree structures are obtained.

How do these integrate with PyTorch torch.nn.Modules?

PyTorch modules, despite looking very similar to trees and structional.Structs, are pretty much the antithesis of structional! This is because they are mutable, and are based around a style of object-oriented programming. For example, every gradient update mutates the same object.

Correspondingly, they are very deliberately not one of the types we treat as trees. Instead, they should be treated as an implementation detail, with Structs instead holding primacy.

For example:

from abc import abstractmethod

from structional import AbstractVar, PRNGKey, Struct
from jaxtyping import Float  # tensor annotations, https://github.com/patrick-kidger/jaxtyping
from torch import from_numpy, nn, Tensor

class AbstractImageDiffusion(Struct):
    shape: AbstractVar[tuple[int, ...]]
    num_steps: AbstractVar[int]

    @abstractmethod
    def step(self, image: Float[Tensor, "*shape"]) -> Float[Tensor, "*shape"]: ...

class LinearImageDiffusion(AbstractImageDiffusion):  # State-of-the-art architecture.
    model: nn.Linear
    shape: tuple[int, ...]
    num_steps: int

    def __init__(self):
        self.model = nn.Linear(256*256, 256*256)
        self.shape = (256, 256)
        self.num_steps = 10

    def step(self, image: Float[Tensor, "*shape"]) -> Float[Tensor, "*shape"]:
        return self.model(image.reshape(-1)).reshape(self.shape)

def inference(model: AbstractImageDiffusion, key: PRNGKey) -> Float[Tensor, "*shape"]:
    x = key.normal(size=model.shape)  # Initial noise
    x = from_numpy(x)  # Zero-copy if we're on the CPU; efficient GPU is an exercise for the reader ;)
    for _ in range(model.num_steps):  # Denoise
        x = model.step(x)
    return x

structional.tree.Tree

A type annotation for tree types. The type of the leaf may be indicated via Tree[LeafType], e.g. Tree[int].

Mapping

This is what the functional programming geeks like to call a 'functor'.

structional.tree.map(*xs: ~_T | structional.tree.Lens[~_T], fn: Callable[..., Any], is_leaf: None | Callable[[Any], bool] = None, structure_from: None | structional.tree.StructureFrom = None, with_path: bool = False) -> ~_T

Given trees *xs with the same structure, this calls y_i = fn(*xs_i) for each leaf index i, where xs_i is the tuple of ith leaves, one from each tree. The result will be a tree with the same structure, whose ith leaf has value y_i.

Nontrivial tree types are tuples, dictionaries, lists, Structs, functools.partial, methods, and function closures.

The order of iteration is deterministic. (This is why sets are treated as leaves rather than as nontrivial tree types: their iteration order is not deterministic.)

Arguments:

  • *xs: The trees to operate over, as described above. Advanced usage: this may also be wrapped in a structional.tree.Lens, to specify only part of a tree structure to operate on. The rest of the tree will remain unchanged.
  • fn: The function to call on all leaves of the structure.
  • is_leaf: Optional callable. Will be called on each node of xs[0]. If the callable returns True, then that node will be treated as a leaf.
  • structure_from: What tree structure to map over. This is required if and only if multiple trees are provided, that is to say if len(xs) > 1.
    1. If structional.tree.StructureFrom.ALL is used, then we require every one of xs to have the exact same structure, and will raise an error if not.
    2. If structional.tree.StructureFrom.FIRST is used, then the tree structure of the first tree xs[0] will be used. (For those coming from JAX, this is JAX's default.) In this case we require only that xs[0] be a tree-prefix for all of xs. For example, tree.map([1], [[2]], fn=fn, structure_from=StructureFrom.FIRST) is valid, and it will return [fn(1, [2])].
    3. If structional.tree.StructureFrom.COMMON is used, then we will locate the largest common tree-prefix and apply fn at each of its leaves.
  • with_path: If False (the default), then fn is called as fn(*xs) for each leaf. If True, then fn is called as fn(*xs, path=path), with path being a structional.tree.Path object that indicates the path to the current leaf.
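To make the three structure_from modes concrete, here is a minimal pure-Python sketch. It is not structional's implementation: toy_map, ToyStructureFrom and ToyStructureError are hypothetical stand-ins for tree.map, StructureFrom and StructureError, and only exact lists, tuples and dicts count as nodes (namedtuples, Structs, partials, methods and closures are ignored).

```python
from enum import Enum
from typing import Any, Callable, Optional

NODE_TYPES = (list, tuple, dict)  # hypothetical: the only node types in this sketch


class ToyStructureFrom(Enum):  # stand-in for structional.tree.StructureFrom
    ALL = "all"
    FIRST = "first"
    COMMON = "common"


class ToyStructureError(Exception):  # stand-in for structional.tree.StructureError
    pass


def _same_shape(a: Any, b: Any) -> bool:
    """True if a and b are nodes of the same type with matching keys or length."""
    if type(a) is not type(b) or type(a) not in NODE_TYPES:
        return False
    return a.keys() == b.keys() if isinstance(a, dict) else len(a) == len(b)


def toy_map(*xs: Any, fn: Callable[..., Any], mode: Optional[ToyStructureFrom] = None) -> Any:
    if mode is None:
        if len(xs) > 1:
            raise TypeError("mode is required when mapping over several trees")
        mode = ToyStructureFrom.ALL  # with a single tree, every mode behaves the same
    first, rest = xs[0], xs[1:]
    if type(first) in NODE_TYPES and all(_same_shape(first, x) for x in rest):
        # Every tree agrees at this node, so recurse into the children.
        if isinstance(first, dict):
            return {k: toy_map(*(x[k] for x in xs), fn=fn, mode=mode) for k in first}
        children = [toy_map(*items, fn=fn, mode=mode) for items in zip(*xs)]
        return children if isinstance(first, list) else tuple(children)
    if type(first) not in NODE_TYPES:
        # Leaf of xs[0]. ALL requires a leaf in every other tree too; FIRST and
        # COMMON pass fn whatever the other trees hold at this position.
        if mode is ToyStructureFrom.ALL and any(type(x) in NODE_TYPES for x in rest):
            raise ToyStructureError(f"leaf {first!r} does not match {rest!r}")
        return fn(*xs)
    # xs[0] has a node here, but some other tree disagrees.
    if mode is ToyStructureFrom.COMMON:
        return fn(*xs)  # the largest common tree-prefix ends at this node
    raise ToyStructureError(f"{first!r} does not match {rest!r}")


def pair(a: Any, b: Any) -> tuple:
    return (a, b)


print(toy_map([1], [[2]], fn=pair, mode=ToyStructureFrom.FIRST))  # [(1, [2])]
print(toy_map([(1,)], [(2, 3)], fn=pair, mode=ToyStructureFrom.COMMON))  # [((1,), (2, 3))]
```

The two calls mirror the FIRST and COMMON examples in this section; under ALL, both of those calls raise instead.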
structure_from=StructureFrom.COMMON

Note that using structure_from=StructureFrom.COMMON can easily hide mistakes in the tree structure. It should be used very rarely.

For example it means that

structional.tree.map(
    [(1,)], [(2, 3)],
    fn=fn,
    structure_from=structional.tree.StructureFrom.COMMON
)
returns [fn((1,), (2, 3))], despite the fact that (1,) and (2, 3) are trees with different structures.

Returns:

A tree with the structure chosen by structure_from, whose leaves are the outputs of fn.

Raises:

A structional.tree.StructureError if:

  • structure_from is StructureFrom.ALL and trees do not match exactly.
  • structure_from is StructureFrom.FIRST and xs[0] is not a tree-prefix of each of xs.

Some useful recipes
  1. To get just the leaves:

    leaves = []
    tree.map(x, fn=leaves.append)  # x is the tree to flatten; the mapped result is discarded
    

  2. To get just the structure:

    structure = tree.map(x, fn=lambda _: None)  # same shape as x, with every leaf replaced by None
    

  3. To extend with a custom tree type:

    def my_map(x, *, fn):
        def wrapped_fn(x):
            if type(x) is MyTypeWithFooAttr:
                return MyTypeWithFooAttr(foo=tree.map(x.foo, fn=wrapped_fn))
            else:
                return fn(x)
        return tree.map(x, fn=wrapped_fn)
    

Function closures

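Function closures are one of the more unusual, but important, kinds of nontrivial tree type. Treating them as trees prevents a common footgun in a more advanced use case: wrapping a Struct in a decorator. Because the resulting closure is itself a tree, the wrapped Struct's leaves remain visible to tree.map, rather than the whole thing becoming a single opaque leaf: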

from structional import Struct

def prints_arguments(fn):
    def wrapped(*args, **kwargs):
        print(f"Called with {args} and {kwargs}")
        return fn(*args, **kwargs)
    return wrapped

class Adder(Struct):
    x: int

    def __call__(self, y: int) -> int:
        return self.x + y

# This is a nontrivial tree.
adds_three = Adder(3)
# This is still a nontrivial tree.
adds_three_loud = prints_arguments(adds_three)
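Python itself exposes what a closure captures: a function's free-variable names live in `__code__.co_freevars`, and the matching cells, in the same order, in `__closure__`. The sketch below uses a plain class in place of the Struct so that it runs without structional, and shows that the Adder wrapped by prints_arguments is still reachable through those attributes. This is what makes it possible to traverse a closure as a tree node; whether structional reads exactly these attributes is an assumption.

```python
from typing import Any, Callable


def prints_arguments(fn: Callable[..., Any]) -> Callable[..., Any]:
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        print(f"Called with {args} and {kwargs}")
        return fn(*args, **kwargs)
    return wrapped


class Adder:  # plain class standing in for the Struct above
    def __init__(self, x: int) -> None:
        self.x = x

    def __call__(self, y: int) -> int:
        return self.x + y


adds_three_loud = prints_arguments(Adder(3))

# `fn` is a free variable of `wrapped`, so Python keeps it in a closure cell.
# co_freevars names the cells; __closure__ holds them in the same order.
captured = {
    name: cell.cell_contents
    for name, cell in zip(adds_three_loud.__code__.co_freevars, adds_three_loud.__closure__)
}
print(captured["fn"].x)  # 3
```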

Meanwhile, PyTorch modules are intentionally not a nontrivial tree type. PyTorch modules have an inherently OOP/mutable design, whilst trees are inherently based around a functional/immutable design. It is too easy to footgun when mixing these approaches; it is better to be explicitly in just one of the two regimes.

structional.tree.StructureFrom

Represents the possible ways in which structional.tree.map can choose the structure to map over.

Attributes:

  • ALL: require that all mapped trees have the exact same structure, and raise an error if not.
  • FIRST: the first mapped tree will define the structure. Raise an error only if the first tree is not a prefix of the other trees.
  • COMMON: locate the largest common tree prefix and map over that. Never raise an error.

structional.tree.StructureError(Exception)

Raised to indicate that a structional.tree.map was performed over incompatible structures.

Replacing

This is what the functional programming geeks like to call a 'lens'.

structional.tree.replace(x: structional.tree.Lens[~_Value], value: Any = sentinel, *, fn: None | Callable = None) -> ~_Value

Replaces a value within a tree, leaving the rest of the tree unchanged.

Arguments:

  • x: the tree, and location within it, to update. These are represented jointly as a structional.tree.Lens.
  • value: the value to insert. Mutually exclusive with fn.
  • fn: a function to call on the old value, returning the new value. Mutually exclusive with value.

Returns:

The updated tree. The original tree is left unchanged.
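A replace of this kind only needs to copy the nodes along the path to the target; everything else can be shared with the original tree. Below is a minimal sketch for nested dicts, lists and tuples, assuming the location is given as a plain sequence of keys rather than as a Lens. toy_replace and _MISSING are hypothetical names.

```python
from typing import Any, Callable, Optional, Sequence

_MISSING = object()  # sentinel: distinguishes "no value given" from value=None


def toy_replace(tree: Any, path: Sequence[Any], value: Any = _MISSING, *,
                fn: Optional[Callable[[Any], Any]] = None) -> Any:
    """Return a copy of tree with the node at tree[path[0]][path[1]]... replaced."""
    if (value is _MISSING) == (fn is None):
        raise ValueError("pass exactly one of `value` or `fn`")
    if not path:
        return value if fn is None else fn(tree)
    key, rest = path[0], path[1:]
    new_child = toy_replace(tree[key], rest, value, fn=fn)
    # Rebuild only this node; its other children are shared, not copied.
    if isinstance(tree, dict):
        return {**tree, key: new_child}
    if isinstance(tree, list):
        out = list(tree)
        out[key] = new_child
        return out
    if isinstance(tree, tuple):
        # Assumes a non-negative integer index.
        return tree[:key] + (new_child,) + tree[key + 1:]
    raise TypeError(f"cannot descend into {type(tree).__name__}")


print(toy_replace({"a": [1, (2, 3)]}, ("a", 1, 0), 20))  # {'a': [1, (20, 3)]}
```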

structional.tree.get(x: structional.tree.Lens[~_Value]) -> Any

Gets the value currently pointed at by a lens.

structional.tree.Lens(structional.Struct)

Specifies a location within a tree structure, typically to specify where tree.map or tree.replace should apply.

Supports .attribute_lookup and [item_lookup] to specify locations within the tree.

Example

lens = tree.Lens(some_tree).some_attr[0].foo["hello"]

# Then used as:
... = tree.map(lens, ...)
... = tree.replace(lens, ...)

__init__(tree: structional.tree.Tree, *trees: structional.tree.Tree)

Arguments:

  • tree: the tree to specify the location within.
  • *trees: additional trees, for use when applying a function over multiple trees with tree.map.

__getattribute__(name: str) -> structional.tree.Lens

__getitem__(item: Any) -> structional.tree.Lens
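The chained .attr[item] syntax can be captured by an object that records each lookup instead of performing it, and a get then replays the recorded lookups. The sketch below is a hypothetical ToyLens with a matching toy_get. It uses `__getattr__` (which Python calls only for attributes not found normally) rather than the `__getattribute__` listed above, and it is a plain class rather than a Struct.

```python
from types import SimpleNamespace
from typing import Any, Tuple


class ToyLens:
    """Hypothetical sketch: records .attr and [item] lookups instead of doing them."""

    def __init__(self, tree: Any, path: Tuple[Tuple[str, Any], ...] = ()) -> None:
        self._tree = tree
        self._path = path

    def __getattr__(self, name: str) -> "ToyLens":
        # Only reached for names not found normally (so never _tree or _path).
        return ToyLens(self._tree, self._path + (("attr", name),))

    def __getitem__(self, item: Any) -> "ToyLens":
        return ToyLens(self._tree, self._path + (("item", item),))


def toy_get(lens: ToyLens) -> Any:
    """Replay the recorded lookups against the tree."""
    node = lens._tree
    for kind, key in lens._path:
        node = getattr(node, key) if kind == "attr" else node[key]
    return node


some_tree = SimpleNamespace(some_attr=[SimpleNamespace(foo={"hello": 42})])
print(toy_get(ToyLens(some_tree).some_attr[0].foo["hello"]))  # 42
```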

Miscellaneous

structional.tree.num_leaves(tree: structional.tree.Tree) -> int

Counts the number of leaves in a tree.
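For trees built only from lists, tuples and dicts, counting leaves is a short recursion. This sketch (toy_num_leaves is a hypothetical name) also shows that an empty container contributes no leaves at all:

```python
from typing import Any


def toy_num_leaves(tree: Any) -> int:
    """Hypothetical sketch: lists, tuples and dicts are nodes; anything else is a leaf."""
    if type(tree) in (list, tuple):
        return sum(toy_num_leaves(child) for child in tree)
    if type(tree) is dict:
        return sum(toy_num_leaves(child) for child in tree.values())
    return 1


print(toy_num_leaves([1, 2, (3, 4)]))  # 4
print(toy_num_leaves({"a": [], "b": ("x",)}))  # 1
```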

structional.tree.pformat_structure(tree: structional.tree.Tree) -> str

Pretty-formats a tree's structure, using * to represent each leaf.

Example

x = [1, 2, (3, 4)]
structional.tree.pformat_structure(x)
# [*, *, (*, *)]
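A recursive formatter in the same spirit, restricted to lists, tuples and dicts, reproduces the example above. toy_pformat_structure is a hypothetical name, and the formatting of dicts and one-element tuples here is an assumption modelled on Python's own repr:

```python
from typing import Any


def toy_pformat_structure(tree: Any) -> str:
    """Hypothetical sketch: replace every leaf with '*', keep the containers."""
    if type(tree) is list:
        return "[" + ", ".join(toy_pformat_structure(c) for c in tree) + "]"
    if type(tree) is tuple:
        inner = ", ".join(toy_pformat_structure(c) for c in tree)
        # One-element tuples keep their trailing comma, as in Python's repr.
        return "(" + inner + ("," if len(tree) == 1 else "") + ")"
    if type(tree) is dict:
        items = (f"{k!r}: {toy_pformat_structure(v)}" for k, v in tree.items())
        return "{" + ", ".join(items) + "}"
    return "*"


print(toy_pformat_structure([1, 2, (3, 4)]))  # [*, *, (*, *)]
```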

structional.tree.register_leaf(x: type) -> None

Registers a type as always being a leaf. Example usage:

import structional

class Foo(structional.Struct):
    some_field: int

leaves = []
structional.tree.map(Foo(some_field=3), fn=leaves.append)
print(leaves)  # [3]

structional.tree.register_leaf(Foo)

leaves = []
structional.tree.map(Foo(some_field=3), fn=leaves.append)
print(leaves)  # [Foo(some_field=3)]

structional.tree.is_registered_leaf_type(cls: type) -> bool

Returns whether register_leaf(cls) or register_leaf(some_parent_class) has been called.

structional.tree.Path(structional.Struct)

Represents a path through a tree structure. Typically obtained by using tree.map(..., with_path=True).

Attributes:

  • pieces: Individual elements of the path. Each one is a structional.tree.Path.Piece.

Class attributes:

  • Piece: union over the following types.
  • ListGetItem(item): index into a list.
  • TupleGetItem(item): index into a tuple.
  • NamedTupleGetAttr(name): attribute on a named tuple.
  • DictGetItem(item): lookup in a dictionary.
  • StructDict(): accessing the .__dict__ of a structional.Struct.
  • PartialFunc(): accessing the .func attribute of a functools.partial.
  • PartialArgs(): accessing the .args attribute of a functools.partial.
  • PartialKeywords(): accessing the .keywords attribute of a functools.partial.
  • MethodFunc(): accessing the .__func__ of a method.
  • MethodSelf(): accessing the .__self__ of a method.
  • ClosureItem(item): index in a function closure.

__init__(*pieces: structional.tree.Path.Piece)

Creates a new Path out of the constituent pieces.

push(*pieces: structional.tree.Path.Piece) -> structional.tree.Path

Returns a new Path with the extra piece(s) appended. The original path remains unchanged.
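A minimal sketch of an immutable path, plus a toy map that passes it to fn as path=... in the same way as with_path=True. ToyPath, toy_map_with_path and the two piece classes are hypothetical, and only lists and dicts are traversed:

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple, Union


@dataclass(frozen=True)
class ListGetItem:
    item: int


@dataclass(frozen=True)
class DictGetItem:
    item: Any


Piece = Union[ListGetItem, DictGetItem]


@dataclass(frozen=True, init=False)
class ToyPath:
    pieces: Tuple[Piece, ...]

    def __init__(self, *pieces: Piece) -> None:
        # Frozen dataclasses block normal assignment, so set the field directly.
        object.__setattr__(self, "pieces", pieces)

    def push(self, *pieces: Piece) -> "ToyPath":
        # Return a new path; self is never modified.
        return ToyPath(*self.pieces, *pieces)


def toy_map_with_path(tree: Any, fn: Callable[..., Any], path: ToyPath = ToyPath()) -> Any:
    """Like with_path=True: fn is called as fn(leaf, path=path)."""
    if type(tree) is list:
        return [toy_map_with_path(c, fn, path.push(ListGetItem(i))) for i, c in enumerate(tree)]
    if type(tree) is dict:
        return {k: toy_map_with_path(v, fn, path.push(DictGetItem(k))) for k, v in tree.items()}
    return fn(tree, path=path)


print(toy_map_with_path({"a": [10, 20]}, lambda leaf, path: len(path.pieces)))  # {'a': [2, 2]}
```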

structional.tree.Static(structional.Struct)

Wraps a value into an empty tree. For example, [] is an empty tree, and a tree.map over it changes nothing.

This class offers a way to build such an empty tree, but with arbitrary metadata (the wrapped value) attached.

Attributes:

  • value: the wrapped value.

__init__(value: ~_Value)

Arguments:

  • value: the value to wrap.
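To see why a Static-style wrapper leaves a tree.map unchanged, here is a toy map that treats a hypothetical ToyStatic as a node with no children, so fn is never called on it or on its wrapped value:

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyStatic:
    value: Any  # metadata carried along; never passed to fn


def toy_map(tree: Any, fn: Callable[[Any], Any]) -> Any:
    if isinstance(tree, ToyStatic):
        # A node with zero children: nothing to map over, so return it unchanged.
        return tree
    if type(tree) is list:
        return [toy_map(c, fn) for c in tree]
    if type(tree) is tuple:
        return tuple(toy_map(c, fn) for c in tree)
    return fn(tree)


print(toy_map([1, ToyStatic("units: metres"), (2,)], lambda x: x * 10))
# [10, ToyStatic(value='units: metres'), (20,)]
```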