Trees
Nested data structures are ubiquitous; the config for a machine learning model is a typical example. Manipulating these (mapping over them, or replacing pieces of them) is essential, and here we provide facilities for doing so.
We refer to such objects as 'trees', which are specifically any nested collection of lists, tuples, namedtuples, dictionaries, functools.partials, methods, function closures, and subclasses of structional.Struct. That last one is particularly important: this is typically how custom tree structures are obtained.
How do these integrate with PyTorch's torch.nn.Modules?
PyTorch modules, despite looking very similar to trees and structional.Structs, are pretty much the antithesis of structional! This is because they are mutable, and are built around an object-oriented style of programming. For example, every gradient update mutates the same object.
Correspondingly, they are very deliberately not one of the types we treat as trees. They should instead be treated as an implementation detail, with Structs holding primacy.
For example:
from abc import abstractmethod

from structional import AbstractVar, PRNGKey, Struct
from jaxtyping import Float  # tensor annotations, https://github.com/patrick-kidger/jaxtyping
from torch import from_numpy, nn, Tensor


class AbstractImageDiffusion(Struct):
    shape: AbstractVar[tuple[int, ...]]
    num_steps: AbstractVar[int]

    @abstractmethod
    def step(self, image: Float[Tensor, "*shape"]) -> Float[Tensor, "*shape"]: ...


class LinearImageDiffusion(AbstractImageDiffusion):  # State-of-the-art architecture.
    model: nn.Linear
    shape: tuple[int, ...]
    num_steps: int

    def __init__(self):
        self.model = nn.Linear(256*256, 256*256)
        self.shape = (256, 256)
        self.num_steps = 10

    def step(self, image: Float[Tensor, "*shape"]) -> Float[Tensor, "*shape"]:
        return self.model(image.reshape(-1)).reshape(self.shape)


def inference(model: AbstractImageDiffusion, key: PRNGKey) -> Float[Tensor, "*shape"]:
    x = key.normal(size=model.shape)  # Initial noise
    x = from_numpy(x)  # Zero-copy if we're on the CPU; efficient GPU is an exercise for the reader ;)
    for _ in range(model.num_steps):  # Denoise
        x = model.step(x)
    return x
structional.tree.Tree
A type annotation for tree types. The type of the leaf may be indicated via
Tree[LeafType], e.g. Tree[int].
Mapping
This is what the functional programming geeks like to call a 'functor'.
structional.tree.map(*xs: ~_T | structional.tree.Lens[~_T], fn: Callable[..., Any], is_leaf: None | Callable[[Any], bool] = None, structure_from: None | structional.tree.StructureFrom = None, with_path: bool = False) -> ~_T
Given trees *xs with the same structure, this calls y_i = fn(*xs_i) for each
leaf index i, where xs_i denotes the ith leaf of each tree in xs. The result will
be a tree with the same structure, whose ith leaf has value y_i.
Nontrivial tree types are tuples, dictionaries, lists, Structs,
functools.partial, methods, and function closures.
The order of iteration is deterministic. (This is the reason that sets are not treated as non-leaf types.)
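The sketch below shows leaf-wise mapping over a single tree in plain Python. It illustrates the idea only and is not structional's implementation. The name toy_map is hypothetical, and the function traverses only lists, tuples, namedtuples and dicts. It does not handle Structs, partials, methods or closures. Sets fall through to the leaf case, matching the rule above.

```python
from typing import Any, Callable


def toy_map(x: Any, fn: Callable[[Any], Any]) -> Any:
    """Hypothetical sketch: apply `fn` to every leaf, rebuilding the containers."""
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple: rebuild with the same type, mapping each field in order.
        return type(x)(*(toy_map(v, fn) for v in x))
    if isinstance(x, (list, tuple)):
        return type(x)(toy_map(v, fn) for v in x)
    if isinstance(x, dict):
        # Dicts iterate in insertion order, so the traversal order is deterministic.
        return {k: toy_map(v, fn) for k, v in x.items()}
    # Anything else (numbers, strings, sets, ...) is treated as a leaf.
    return fn(x)


config = {"lr": 0.1, "layers": [32, 64], "shape": (3, 3)}
print(toy_map(config, lambda leaf: leaf * 2))
# {'lr': 0.2, 'layers': [64, 128], 'shape': (6, 6)}
```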
Arguments:
- *xs: The trees to operate over, as described above. Advanced usage: this may also be wrapped in a structional.tree.Lens, to specify only part of a tree structure to operate on. The rest of the tree will remain unchanged.
- fn: The function to call on all leaves of the structure.
- is_leaf: Optional callable. Will be called on each node of xs[0]. If the callable returns True, then that node will be treated as a leaf.
- structure_from: What tree structure to map over. This is required if and only if multiple trees are provided, that is to say if len(xs) > 1.
  - If structional.tree.StructureFrom.ALL is used, then we require every one of xs to have the exact same structure, and will raise an error if not.
  - If structional.tree.StructureFrom.FIRST is used, then the tree structure of the first tree xs[0] will be used. (For those coming from JAX, this is JAX's default.) In this case we require only that xs[0] be a tree-prefix of all of xs. For example, tree.map([1], [[2]], fn=fn, structure_from=tree.StructureFrom.FIRST) is valid, and it will return [fn(1, [2])].
  - If structional.tree.StructureFrom.COMMON is used, then we will locate the largest common tree-prefix and apply fn at its leaves.
- with_path: If False (the default), then fn is called as fn(*xs) for each leaf. If True, then fn is called as fn(*xs, path=path), with path being a structional.tree.Path object that indicates the path to the current leaf.
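The FIRST rule can be sketched in plain Python as follows. This is a hypothetical illustration that handles only plain lists, tuples and dicts. toy_map_first is not part of structional. Where the real function would raise structional.tree.StructureError, this sketch raises ValueError.

```python
from typing import Any, Callable


def toy_map_first(*xs: Any, fn: Callable[..., Any]) -> Any:
    """Hypothetical sketch of StructureFrom.FIRST.

    The structure of xs[0] drives the recursion. Wherever xs[0] has a leaf,
    `fn` receives the corresponding (possibly non-leaf) subtrees of the others.
    """
    first, rest = xs[0], xs[1:]
    if isinstance(first, (list, tuple)):
        for r in rest:
            # xs[0] must be a tree-prefix of every other tree.
            if type(r) is not type(first) or len(r) != len(first):
                raise ValueError(f"{first!r} is not a tree-prefix of {r!r}")
        return type(first)(
            toy_map_first(f, *(r[i] for r in rest), fn=fn)
            for i, f in enumerate(first)
        )
    if isinstance(first, dict):
        for r in rest:
            if not isinstance(r, dict) or r.keys() != first.keys():
                raise ValueError(f"{first!r} is not a tree-prefix of {r!r}")
        return {k: toy_map_first(v, *(r[k] for r in rest), fn=fn) for k, v in first.items()}
    # Leaf of xs[0]: pass the matching subtrees of all trees to fn.
    return fn(*xs)


print(toy_map_first([1], [[2]], fn=lambda a, b: (a, b)))  # [(1, [2])]
```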
structure_from=StructureFrom.COMMON
Note that using structure_from=StructureFrom.COMMON can easily hide mistakes
in the tree structure. It should be used very rarely.
For example, it means that
structional.tree.map(
    [(1,)], [(2, 3)],
    fn=fn,
    structure_from=structional.tree.StructureFrom.COMMON
)
returns [fn((1,), (2, 3))], despite the fact that (1,) and (2, 3) are
trees with different structures.
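The sketch below shows the COMMON rule in plain Python, again over lists, tuples and dicts only. toy_map_common is hypothetical, not a structional function. It shows why this mode never raises. Wherever the trees disagree, the disagreeing subtrees are passed to fn whole, as if they were leaves.

```python
from typing import Any, Callable


def toy_map_common(*xs: Any, fn: Callable[..., Any]) -> Any:
    """Hypothetical sketch of StructureFrom.COMMON: map over the largest common prefix."""
    first = xs[0]
    same_type = all(type(x) is type(first) for x in xs)
    if same_type and isinstance(first, (list, tuple)) and len({len(x) for x in xs}) == 1:
        return type(first)(toy_map_common(*children, fn=fn) for children in zip(*xs))
    if same_type and isinstance(first, dict) and all(x.keys() == first.keys() for x in xs):
        return {k: toy_map_common(*(x[k] for x in xs), fn=fn) for k in first}
    # The trees disagree here (or these are leaves): stop and call fn on the subtrees.
    return fn(*xs)


print(toy_map_common([(1,)], [(2, 3)], fn=lambda a, b: (a, b)))  # [((1,), (2, 3))]
```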
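Returns:
The result of mapping fn over the provided trees: a tree with the structure selected by structure_from, whose leaves are the values returned by fn.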
Raises:
A structional.tree.StructureError if:
- structure_from is StructureFrom.ALL and the trees do not match exactly.
- structure_from is StructureFrom.FIRST and xs[0] is not a tree-prefix of each of xs.
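For comparison, here is a sketch of the strict ALL rule under the same assumptions: plain lists, tuples and dicts only. ToyStructureError is a hypothetical stand-in for structional.tree.StructureError. The function raises as soon as any pair of nodes disagrees. Note that [1] and [[2]], which are accepted under FIRST, are rejected here.

```python
from typing import Any, Callable


class ToyStructureError(Exception):
    """Hypothetical stand-in for structional.tree.StructureError."""


def _same_node(a: Any, b: Any) -> bool:
    # Nodes match if both are leaves, or both are the same container type
    # with the same length (and, for dicts, the same keys).
    containers = (list, tuple, dict)
    if isinstance(a, containers) or isinstance(b, containers):
        if type(a) is not type(b) or len(a) != len(b):
            return False
        return not isinstance(a, dict) or a.keys() == b.keys()
    return True


def toy_map_all(*xs: Any, fn: Callable[..., Any]) -> Any:
    """Hypothetical sketch of StructureFrom.ALL."""
    first = xs[0]
    for x in xs[1:]:
        if not _same_node(first, x):
            raise ToyStructureError(f"structure mismatch: {first!r} vs {x!r}")
    if isinstance(first, dict):
        return {k: toy_map_all(*(x[k] for x in xs), fn=fn) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(toy_map_all(*children, fn=fn) for children in zip(*xs))
    return fn(*xs)


print(toy_map_all([1, (2, 3)], [4, (5, 6)], fn=lambda a, b: a + b))  # [5, (7, 9)]
```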
Some useful recipes
- To get just the leaves of a tree x:
    leaves = []
    tree.map(x, fn=leaves.append)
- To get just the structure of a tree x:
    structure = tree.map(x, fn=lambda _: None)
- To extend with a custom tree type:
    def my_map(x, *, fn):
        def wrapped_fn(x):
            if type(x) is MyTypeWithFooAttr:
                return MyTypeWithFooAttr(foo=tree.map(x.foo, fn=wrapped_fn))
            else:
                return fn(x)
        return tree.map(x, fn=wrapped_fn)
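The recipes above can be tried without structional. The sketch below substitutes a hypothetical toy_map, which handles only lists, tuples and dicts, for tree.map. It uses a frozen dataclass as the hypothetical MyTypeWithFooAttr. The shape of my_map is taken from the recipe. The leaves recipe works because list.append returns None, so the mapped result is a tree of Nones that is simply discarded.

```python
from dataclasses import dataclass
from typing import Any, Callable


def toy_map(x: Any, fn: Callable[[Any], Any]) -> Any:
    # Hypothetical stand-in for tree.map: anything that is not a list/tuple/dict is a leaf.
    if isinstance(x, (list, tuple)):
        return type(x)(toy_map(v, fn) for v in x)
    if isinstance(x, dict):
        return {k: toy_map(v, fn) for k, v in x.items()}
    return fn(x)


@dataclass(frozen=True)
class MyTypeWithFooAttr:  # hypothetical custom container with one child, `foo`
    foo: Any


def my_map(x: Any, *, fn: Callable[[Any], Any]) -> Any:
    def wrapped_fn(x: Any) -> Any:
        if type(x) is MyTypeWithFooAttr:
            # Recurse into the custom type's child, reusing wrapped_fn for nested instances.
            return MyTypeWithFooAttr(foo=toy_map(x.foo, wrapped_fn))
        return fn(x)
    return toy_map(x, wrapped_fn)


leaves = []
toy_map([1, (2, 3)], leaves.append)  # leaves is now [1, 2, 3]

data = [1, MyTypeWithFooAttr(foo=(2, MyTypeWithFooAttr(foo=3)))]
print(my_map(data, fn=lambda v: v + 10))
# [11, MyTypeWithFooAttr(foo=(12, MyTypeWithFooAttr(foo=13)))]
```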
Function closures
Function closures are one of the more unusual, but important, kinds of nontrivial tree type. In particular, treating them as trees prevents a common footgun in the following, more advanced, use case: wrapping a Struct in a decorator.
def prints_arguments(fn):
    def wrapped(*args, **kwargs):
        print(f"Called with {args} and {kwargs}")
        return fn(*args, **kwargs)
    return wrapped

class Adder(Struct):
    x: int

    def __call__(self, y: int) -> int:
        return self.x + y

# This is a nontrivial tree.
adds_three = Adder(3)
# This is still a nontrivial tree: `wrapped` is a closure over `fn`, i.e. over adds_three.
adds_three_loud = prints_arguments(adds_three)
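Standard CPython exposes what a closure captures through the function's __closure__ cells, each of which has a cell_contents attribute. The matching variable names are in __code__.co_freevars. The sketch below uses these attributes to recover the Adder from inside adds_three_loud. Whether structional walks closures in exactly this way is an assumption. A plain class stands in for the Struct so that the snippet runs without structional.

```python
def prints_arguments(fn):
    def wrapped(*args, **kwargs):
        print(f"Called with {args} and {kwargs}")
        return fn(*args, **kwargs)
    return wrapped


class Adder:  # plain-class stand-in for the Struct above
    def __init__(self, x: int):
        self.x = x

    def __call__(self, y: int) -> int:
        return self.x + y


adds_three_loud = prints_arguments(Adder(3))

# Names of the variables captured from the enclosing scope, and their cells.
names = adds_three_loud.__code__.co_freevars  # ('fn',)
captured = {name: cell.cell_contents for name, cell in zip(names, adds_three_loud.__closure__)}
print(captured["fn"].x)  # 3: the Adder, and hence its field, is reachable through the closure
```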
Meanwhile, PyTorch modules are intentionally not a nontrivial tree type. PyTorch modules have an inherently OOP/mutable design, whilst trees are inherently based around a functional/immutable design. It is too easy to footgun when mixing these approaches, so it is better to be explicitly in just one of these two regimes.
structional.tree.StructureFrom
Represents the possible ways in which structional.tree.map can choose the
structure to map over.
Attributes:
- ALL: require that all mapped trees have the exact same structure, and raise an error if not.
- FIRST: the first mapped tree defines the structure. Raise an error only if the first tree is not a prefix of the other trees.
- COMMON: locate the largest common tree prefix and map over that. Never raise an error.
structional.tree.StructureError(Exception)
Raised to indicate that a structional.tree.map was performed over incompatible
structures.
Replacing
This is what the functional programming geeks like to call a 'lens'.
structional.tree.replace(x: structional.tree.Lens[~_Value], value: Any = sentinel, *, fn: None | Callable = None) -> ~_Value
Replaces a value within a tree, leaving the rest of the tree unchanged.
Arguments:
- x: the tree, and location within it, to update. These are represented jointly as a structional.tree.Lens.
- value: the value to insert. Mutually exclusive with fn.
- fn: a function to call on the old value, returning the new value. Mutually exclusive with value.
Returns:
The updated tree. The original tree is left unchanged.
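The behaviour of tree.replace can be sketched in plain Python: it returns an updated copy and leaves the original untouched. This hypothetical toy_replace takes an explicit sequence of keys and indices in place of a structional.tree.Lens. It only handles dicts, lists and tuples. Like the documented function, it accepts exactly one of value and fn. Untouched subtrees are shared between the old and new trees rather than copied.

```python
from typing import Any, Callable, Optional, Sequence

_MISSING = object()  # sentinel meaning "no value given"


def toy_replace(
    x: Any,
    path: Sequence[Any],
    value: Any = _MISSING,
    *,
    fn: Optional[Callable[[Any], Any]] = None,
) -> Any:
    """Hypothetical sketch: return a copy of `x` with the node at `path` replaced."""
    if (value is _MISSING) == (fn is None):
        raise TypeError("exactly one of `value` and `fn` must be given")
    if not path:
        return fn(x) if fn is not None else value
    key, rest = path[0], path[1:]
    if isinstance(x, dict):
        new = dict(x)  # shallow copy: sibling subtrees are shared, not copied
        new[key] = toy_replace(x[key], rest, value, fn=fn)
        return new
    if isinstance(x, (list, tuple)):
        items = list(x)
        items[key] = toy_replace(x[key], rest, value, fn=fn)
        return type(x)(items)
    raise TypeError(f"cannot index into leaf {x!r}")


old = {"model": {"layers": [32, 64]}, "lr": 0.1}
new = toy_replace(old, ["model", "layers", 1], 128)
# new == {"model": {"layers": [32, 128]}, "lr": 0.1}, while old is unchanged.
```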
structional.tree.get(x: structional.tree.Lens[~_Value]) -> Any
Gets the value currently pointed at by a lens.
structional.tree.Lens(structional.Struct)
Specifies a location within a tree structure, typically to indicate where
tree.map or tree.replace should apply.
Supports .attribute_lookup and [item_lookup] to specify locations within the
tree.
Example
lens = tree.Lens(some_tree).some_attr[0].foo["hello"]
# Then used as:
... = tree.map(lens, ...)
... = tree.replace(lens, ...)
__init__(tree: structional.tree.Tree, *trees: structional.tree.Tree)
Arguments:
- tree: the tree to specify the location within.
- *trees: additional trees, for use when applying a function over multiple trees with tree.map.
__getattribute__(name: str) -> structional.tree.Lens
__getitem__(item: Any) -> structional.tree.Lens
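A lens like this can work by recording lookups instead of performing them. The hypothetical ToyLens below intercepts attribute access with __getattribute__, matching the documented signature, and item access with __getitem__. A hypothetical toy_get then replays the recorded path, in the spirit of tree.get. This is a sketch, not structional's implementation. Implicit special-method calls such as lens[0] are looked up on the type, so they bypass the instance-level __getattribute__.

```python
from types import SimpleNamespace
from typing import Any


class ToyLens:
    """Hypothetical sketch: record attribute/item lookups without performing them."""

    def __init__(self, tree: Any, path: tuple = ()):
        object.__setattr__(self, "_state", (tree, path))

    def __getattribute__(self, name: str) -> "ToyLens":
        # Every attribute access becomes a recorded step, so read state via object.
        tree, path = object.__getattribute__(self, "_state")
        return ToyLens(tree, path + (("attr", name),))

    def __getitem__(self, item: Any) -> "ToyLens":
        tree, path = object.__getattribute__(self, "_state")
        return ToyLens(tree, path + (("item", item),))


def toy_get(lens: ToyLens) -> Any:
    """Replay the recorded lookups against the wrapped tree."""
    value, path = object.__getattribute__(lens, "_state")
    for kind, key in path:
        value = getattr(value, key) if kind == "attr" else value[key]
    return value


some_tree = SimpleNamespace(some_attr=[SimpleNamespace(foo={"hello": 42})])
lens = ToyLens(some_tree).some_attr[0].foo["hello"]
print(toy_get(lens))  # 42
```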
Miscellaneous
structional.tree.num_leaves(tree: structional.tree.Tree) -> int
Counts the number of leaves in a tree.
structional.tree.pformat_structure(tree: structional.tree.Tree) -> str
Pretty-formats a tree's structure, using * to represent each leaf.
Example
x = [1, 2, (3, 4)]
structional.tree.pformat_structure(x)
# [*, *, (*, *)]
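Both helpers above can be sketched for plain lists, tuples and dicts. toy_num_leaves and toy_pformat_structure are hypothetical names. The formatting of single-element tuples and dict keys shown here is a guess, not structional's actual output.

```python
from typing import Any


def toy_num_leaves(x: Any) -> int:
    """Hypothetical sketch: count the leaves of a tree."""
    if isinstance(x, (list, tuple)):
        return sum(toy_num_leaves(v) for v in x)
    if isinstance(x, dict):
        return sum(toy_num_leaves(v) for v in x.values())
    return 1


def toy_pformat_structure(x: Any) -> str:
    """Hypothetical sketch: format a tree's structure, writing `*` for each leaf."""
    if isinstance(x, list):
        return "[" + ", ".join(toy_pformat_structure(v) for v in x) + "]"
    if isinstance(x, tuple):
        inner = ", ".join(toy_pformat_structure(v) for v in x)
        return "(" + inner + ("," if len(x) == 1 else "") + ")"
    if isinstance(x, dict):
        return "{" + ", ".join(f"{k!r}: {toy_pformat_structure(v)}" for k, v in x.items()) + "}"
    return "*"


print(toy_num_leaves([1, 2, (3, 4)]))  # 4
print(toy_pformat_structure([1, 2, (3, 4)]))  # [*, *, (*, *)]
```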
structional.tree.register_leaf(x: type) -> None
Registers a type as always being a leaf. Example usage:
class Foo(structional.Struct):
    some_field: int
leaves = []
structional.tree.map(Foo(some_field=3), fn=leaves.append)
print(leaves) # [3]
structional.tree.register_leaf(Foo)
leaves = []
structional.tree.map(Foo(some_field=3), fn=leaves.append)
print(leaves) # [Foo(some_field=3)]
structional.tree.is_registered_leaf_type(cls: type) -> bool
Whether register_leaf(cls) or register_leaf(some_parent_class) has been
called.
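A leaf registry with the parent-class behaviour described above can be sketched by checking a class's method resolution order (cls.__mro__). All names here are hypothetical stand-ins for structional.tree.register_leaf and structional.tree.is_registered_leaf_type. The toy map handles only lists, tuples and dicts.

```python
from typing import Any, Callable

_LEAF_TYPES: set = set()  # hypothetical registry of leaf types


def toy_register_leaf(cls: type) -> None:
    _LEAF_TYPES.add(cls)


def toy_is_registered_leaf_type(cls: type) -> bool:
    # Walk the MRO, so registering a parent class also covers its subclasses.
    return any(base in _LEAF_TYPES for base in cls.__mro__)


def toy_map(x: Any, fn: Callable[[Any], Any]) -> Any:
    if toy_is_registered_leaf_type(type(x)):
        return fn(x)  # registered types are never traversed
    if isinstance(x, (list, tuple)):
        return type(x)(toy_map(v, fn) for v in x)
    if isinstance(x, dict):
        return {k: toy_map(v, fn) for k, v in x.items()}
    return fn(x)


class Pair(tuple):
    pass


class NamedPair(Pair):
    pass


toy_register_leaf(Pair)
leaves = []
toy_map([NamedPair((1, 2)), (3, 4)], leaves.append)
# The subclass NamedPair is kept whole; the plain tuple is traversed.
```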
structional.tree.Path(structional.Struct)
Represents a path through a tree structure. Typically obtained by using
tree.map(..., with_path=True).
Attributes:
- pieces: Individual elements of the path. Each one is a structional.tree.Path.Piece.
Class attributes:
- Piece: union over the following types.
  - ListGetItem(item): index into a list.
  - TupleGetItem(item): index into a tuple.
  - NamedTupleGetAttr(name): attribute on a named tuple.
  - DictGetItem(item): lookup in a dictionary.
  - StructDict(): accessing the .__dict__ of a structional.Struct.
  - PartialFunc(): accessing the .func attribute of a functools.partial.
  - PartialArgs(): accessing the .args attribute of a functools.partial.
  - PartialKeywords(): accessing the .keywords attribute of a functools.partial.
  - MethodFunc(): accessing the .__func__ of a method.
  - MethodSelf(): accessing the .__self__ of a method.
  - ClosureItem(item): index in a function closure.
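The idea behind with_path=True can be sketched by threading a path through the recursion. In this hypothetical toy_map_with_path, path pieces are plain (kind, key) tuples named after the Piece types listed above. They are not the real Piece classes. Only exact lists, tuples and dicts are traversed. As with the documented with_path option, fn is called with the path as a keyword argument.

```python
from typing import Any, Callable


def toy_map_with_path(x: Any, fn: Callable[..., Any], path: tuple = ()) -> Any:
    """Hypothetical sketch: call fn(leaf, path=path), where path records how the leaf was reached."""
    if type(x) is list:
        return [toy_map_with_path(v, fn, path + (("ListGetItem", i),)) for i, v in enumerate(x)]
    if type(x) is tuple:
        return tuple(toy_map_with_path(v, fn, path + (("TupleGetItem", i),)) for i, v in enumerate(x))
    if type(x) is dict:
        return {k: toy_map_with_path(v, fn, path + (("DictGetItem", k),)) for k, v in x.items()}
    return fn(x, path=path)


paths = []
toy_map_with_path({"a": [1, (2,)]}, lambda leaf, path: paths.append(path))
for p in paths:
    print(p)
```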
structional.tree.Static(structional.Struct)
Wraps a value into an empty tree. For example, [] is an empty tree, and a
tree.map over it changes nothing.
This class offers a way to build such an empty tree, but with arbitrary metadata (the wrapped value) attached.
Attributes:
value: the wrapped value.
__init__(value: ~_Value)
Arguments:
value: the value to wrap.
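The sketch below shows why it is useful to wrap metadata in an empty tree. It uses a hypothetical frozen-dataclass ToyStatic and a toy map over lists, tuples and dicts. Without the wrapper, the string "encoder" would be a leaf and would be multiplied along with the numbers.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyStatic:  # hypothetical stand-in for structional.tree.Static
    value: Any


def toy_map(x: Any, fn: Callable[[Any], Any]) -> Any:
    if isinstance(x, ToyStatic):
        # An empty tree: it has no leaves, so fn is never called and it comes back unchanged.
        return x
    if isinstance(x, (list, tuple)):
        return type(x)(toy_map(v, fn) for v in x)
    if isinstance(x, dict):
        return {k: toy_map(v, fn) for k, v in x.items()}
    return fn(x)


params = {"weight": 2.0, "name": ToyStatic("encoder")}
out = toy_map(params, lambda v: v * 10)
print(out)  # {'weight': 20.0, 'name': ToyStatic(value='encoder')}
```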