FAQ

How do Structs integrate with PyTorch torch.nn.Modules?

PyTorch modules, despite looking very similar to trees and structional.Structs, are pretty much the antithesis of structional! This is because they are mutable, and are based around a style of object-oriented programming. For example, every gradient update mutates the same object.

Correspondingly, they are very deliberately not one of the types we treat as trees. Instead, they should be treated as an implementation detail, with Structs instead holding primacy.

For example:

from abc import abstractmethod
from structional import AbstractVar, PRNGKey, Struct
from jaxtyping import Float  # tensor annotations, https://github.com/patrick-kidger/jaxtyping
from torch import from_numpy, nn, Tensor

class AbstractImageDiffusion(Struct):
    shape: AbstractVar[tuple[int, ...]]
    num_steps: AbstractVar[int]

    @abstractmethod
    def step(self, image: Float[Tensor, "*shape"]) -> Float[Tensor, "*shape"]: ...

class LinearImageDiffusion(AbstractImageDiffusion):  # State-of-the-art architecture.
    model: nn.Linear
    shape: tuple[int, ...]
    num_steps: int

    def __init__(self):
        self.model = nn.Linear(256*256, 256*256)
        self.shape = (256, 256)
        self.num_steps = 10

    def step(self, image: Float[Tensor, "*shape"]) -> Float[Tensor, "*shape"]:
        return self.model(image.reshape(-1)).reshape(self.shape)

def inference(model: AbstractImageDiffusion, key: PRNGKey) -> Float[Tensor, "*shape"]:
    x = key.normal(size=model.shape)  # Initial noise
    x = from_numpy(x)  # Zero-copy if we're on the CPU; efficient GPU is an exercise for the reader ;)
    for _ in range(model.num_steps):  # Denoise
        x = model.step(x)
    return x

How should this be used alongside JAX/Equinox?

It shouldn't! This library is really just a port of all the things I find myself enjoying about the JAX/Equinox ecosystem:

  • Struct is a stricter version of equinox.Module.
  • PRNGKey is equivalent to jax.random.key.
  • Everything in tree is an analogue of JAX's pytrees. (They are not compatible, so don't try to mix them.)
    • tree.map is equivalent to jax.tree.map
    • tree.replace is equivalent to equinox.tree_at

But if you're using JAX+Equinox, then you can pretty much just use their existing functionality already!
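To make the tree analogy concrete, here is a minimal sketch of what "map" and "replace" look like over immutable data. It assumes frozen standard-library dataclasses as a stand-in for Structs, and the helpers tree_map and tree_replace are hypothetical illustrations, not structional's actual tree.map and tree.replace (whose signatures may differ). In particular, equinox.tree_at selects the node to change with a function, whereas this sketch uses a simpler tuple of attribute names.

```python
import dataclasses
from typing import Any, Callable


# Hypothetical helpers for illustration only: NOT structional's tree.map / tree.replace.
def tree_map(fn: Callable[[Any], Any], node: Any) -> Any:
    """Apply fn to every leaf, rebuilding containers instead of mutating them."""
    if dataclasses.is_dataclass(node) and not isinstance(node, type):
        changes = {
            field.name: tree_map(fn, getattr(node, field.name))
            for field in dataclasses.fields(node)
        }
        return dataclasses.replace(node, **changes)  # New object; original untouched.
    if isinstance(node, (list, tuple)):
        return type(node)(tree_map(fn, child) for child in node)
    if isinstance(node, dict):
        return {k: tree_map(fn, v) for k, v in node.items()}
    return fn(node)  # Anything else is treated as a leaf.


def tree_replace(node: Any, path: tuple[str, ...], value: Any) -> Any:
    """Return a copy of node with the attribute at path set to value (assumed path-based API)."""
    if not path:
        return value
    head, *rest = path
    child = getattr(node, head)
    return dataclasses.replace(node, **{head: tree_replace(child, tuple(rest), value)})


@dataclasses.dataclass(frozen=True)
class Linear:
    weight: tuple[float, ...]
    bias: float


@dataclasses.dataclass(frozen=True)
class Model:
    layer: Linear
    scale: float


model = Model(layer=Linear(weight=(1.0, 2.0), bias=0.5), scale=3.0)
doubled = tree_map(lambda x: 2 * x, model)
print(doubled.layer.weight)  # (2.0, 4.0)
print(model.layer.weight)  # (1.0, 2.0): the original is untouched

updated = tree_replace(model, ("layer", "bias"), 0.0)
print(updated.layer.bias, model.layer.bias)  # 0.0 0.5
```

Note that unchanged subtrees are shared rather than copied (updated.layer.weight is the very same tuple as model.layer.weight), which is safe precisely because nothing can mutate them.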

Should Structs be thought of as functional programming (FP) or object-oriented programming (OOP)?

Functional programming. Whilst both FP and OOP use structured types, the difference is essentially that in FP they are immutable (and we create new objects if we need to change them) whilst in OOP they are mutable (and modified through their methods). In our case, Structs are immutable.
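The difference can be seen in a few lines of plain Python. This sketch assumes a frozen dataclass as a stand-in for a Struct (it is not how structional is implemented): the OOP-style class changes its own state, while the FP-style class returns a new object and leaves the old one intact.

```python
import dataclasses


# OOP style: methods mutate the object in place.
class MutableCounter:
    def __init__(self, count: int = 0) -> None:
        self.count = count

    def increment(self) -> None:
        self.count += 1


# FP style: the object is frozen, so "changing" it means building a new one.
@dataclasses.dataclass(frozen=True)
class Counter:
    count: int = 0

    def increment(self) -> "Counter":
        return dataclasses.replace(self, count=self.count + 1)


oop = MutableCounter()
oop.increment()
print(oop.count)  # 1: same object, new state

fp = Counter()
fp2 = fp.increment()
print(fp.count, fp2.count)  # 0 1: the old value is still available

try:
    fp.count = 5  # Assignment to a frozen instance is rejected.
except dataclasses.FrozenInstanceError:
    print("frozen")
```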

What are similar reference points amongst other languages?

Structs combine Rust-style traits with Julia-style abstract/final.
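The "abstract/final" half of that description means: abstract classes define an interface and can be subclassed but not instantiated, while concrete classes can be instantiated but not subclassed. Plain Python only enforces the first half (via abc); typing.final is checked only by static type checkers. The following is a hypothetical sketch of enforcing both at runtime with a made-up StrictBase class; it is not structional's implementation, and it assumes "abstract" means "has unimplemented abstract methods" (so an abstract class with only abstract attributes would be misclassified as concrete here).

```python
import abc
import inspect


class StrictBase(abc.ABC):
    """Hypothetical base: abstract classes may be subclassed, concrete ones may not."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for base in cls.__bases__:
            # Only police classes in this hierarchy, excluding the root itself.
            if base is StrictBase or not issubclass(base, StrictBase):
                continue
            if not inspect.isabstract(base):
                raise TypeError(f"{base.__name__} is concrete (final) and cannot be subclassed")


class AbstractShape(StrictBase):  # Trait-like: an interface with no implementation.
    @abc.abstractmethod
    def area(self) -> float: ...


class Square(AbstractShape):  # Concrete, and therefore final.
    def __init__(self, side: float) -> None:
        self.side = side

    def area(self) -> float:
        return self.side ** 2


print(Square(3.0).area())  # 9.0

try:
    AbstractShape()  # Abstract classes cannot be instantiated.
except TypeError as e:
    print("abstract:", e)

try:
    class BigSquare(Square):  # Concrete classes cannot be subclassed.
        pass
except TypeError as e:
    print("final:", e)
```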

Meanwhile in Python, we have several JAX/Equinox reference points:

  • Structs are inspired by equinox.Module;
  • PRNGKey is inspired by jax.random.key;
  • tree.map is inspired by jax.tree.map;
  • tree.replace is inspired by equinox.tree_at.

(Broadly speaking these ideas are pretty standard in functional programming.)

Aren't there a lot of 'functional Python' libraries?

Yup. Most of them tend to either implement Haskell-isms (foldl, monads) or use all those complicated functional programming words (functors, ...) 😄. In contrast, structional is more about introducing an 'opinionated type system for Python' (Structs), which is a bit different.

Where should abstract base classes live?

When code is large enough to be split across multiple files, abstract base classes (ABCs) can be placed either alongside their subclasses:

# bar.py
class AbstractFoo(Struct): ...

class ConcreteFoo(AbstractFoo): ...

# qux.py
def frobnicate(x: AbstractFoo): ...

or they could live alongside their consumers:

# baz.py
class AbstractFoo(Struct): ...

def frobnicate(x: AbstractFoo): ...

# fizzle.py
class ConcreteFoo(AbstractFoo): ...

There's no hard rule, but about 80% of the time I find it's more useful to keep the abstract class (AbstractFoo) next to its consumer (frobnicate). The other 20% of the time I find it's most useful to keep it next to its subclasses (ConcreteFoo).

The rationale for this is that typically it is the consumer (frobnicate) that is the 'first class citizen' of our code, and the ABC mostly just exists as a way to define the interface required by that consumer.

This approach also pairs well with how ABCs are often used to support extensibility: later authors can come along and define their own concrete subclasses.
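Here is a single-file sketch of the "ABC next to its consumer" layout, with comments marking which hypothetical module each piece would live in. It assumes abc.ABC as a stand-in for Struct, and the value method and ThirdPartyFoo class are made up for illustration. The point is that frobnicate depends only on the interface, so a later package can add a new implementation without touching baz.py.

```python
import abc


# --- baz.py: the consumer and the interface it requires live together ---
class AbstractFoo(abc.ABC):
    @abc.abstractmethod
    def value(self) -> int: ...


def frobnicate(x: AbstractFoo) -> int:
    # Depends only on the interface, never on a concrete subclass.
    return 2 * x.value()


# --- fizzle.py: one concrete implementation ---
class ConcreteFoo(AbstractFoo):
    def value(self) -> int:
        return 21


# --- a later, third-party package: a new implementation, baz.py unchanged ---
class ThirdPartyFoo(AbstractFoo):
    def value(self) -> int:
        return 5


print(frobnicate(ConcreteFoo()))  # 42
print(frobnicate(ThirdPartyFoo()))  # 10
```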

Where the choice matters most is when the two files live in separately-versioned packages, and in that case it really is usually best to keep the ABC alongside its consumer.