Abstract/final rules
The abstract/final design pattern
Every class must be either:
(a) abstract (it can be subclassed, but not instantiated); or
(b) final (it can be instantiated, but not subclassed).
In addition:
(a) class instances must not be mutated.
(b) only abstractmethods and abstract attributes can be overridden. (So once they've been implemented, then a subclass may not override them.)
(c) abstract classes must have a name beginning with 'Abstract' or '_Abstract'.
(d) only concrete classes can define __init__ methods.
(e) never use super().
In principle you could pay close attention and manually follow this style of (functional) programming. As a practical matter, structional.Struct exists to enforce these rules on its subclasses. In the codebases I write, 99% of classes are either Structs or part of a domain-specific hierarchy (torch.nn.Module, pydantic.BaseModel). In particular, anything that might be 'just a class' (or 'just a dataclass') will be a Struct.
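To give a feel for what following the rules by hand looks like, here is a minimal sketch that uses only the standard library rather than structional.Struct. The names AbstractGreeter and PoliteGreeter are hypothetical. Note the limits of this approximation: typing.final is only checked by static type checkers, and a frozen dataclass only blocks attribute assignment, so nothing here enforces the full set of rules.

```python
import abc
import dataclasses
from typing import final


class AbstractGreeter(abc.ABC):
    """Abstract: can be subclassed, but not instantiated."""

    @abc.abstractmethod
    def greet(self, name: str) -> str: ...


@final  # a hint for static type checkers only; not enforced at runtime
@dataclasses.dataclass(frozen=True)  # blocks attribute assignment after __init__
class PoliteGreeter(AbstractGreeter):
    """Final: can be instantiated, and should not be subclassed."""

    greeting: str = "Hello"

    def greet(self, name: str) -> str:
        return f"{self.greeting}, {name}!"


greeter = PoliteGreeter()
print(greeter.greet("Ada"))
```

Trying AbstractGreeter() raises a TypeError, and assigning to greeter.greeting raises dataclasses.FrozenInstanceError. Everything else (no overriding, naming, where __init__ lives) still relies on discipline in this sketch.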
Let's see how this design pattern is arrived at.
Level 1: Abstract base classes (ABCs) as interfaces
Let's start off with something very standard: using ABCs to define interfaces. Here's an example.
    import abc

    import structional

    class AbstractOptimiser(structional.Struct):
        @abc.abstractmethod
        def init(self, model): ...

        @abc.abstractmethod
        def update(self, model, grads, state): ...

    class Adam(AbstractOptimiser):
        learning_rate: float
        beta1: float = 0.9
        beta2: float = 0.999

        def init(self, model):
            ...  # some implementation
            return initial_state

        def update(self, model, grads, state):
            ...  # some implementation
            return new_model, new_state

    def make_step(model, data, opt_state, optimiser: AbstractOptimiser):
        grads = ...  # some backpropagation
        return optimiser.update(model, grads, opt_state)

    def train(model, dataloader, optimiser: AbstractOptimiser):
        opt_state = optimiser.init(model)
        for data in dataloader:
            model, opt_state = make_step(model, data, opt_state, optimiser)
        return model

    model = ...  # some model
    dataloader = ...  # some dataloader
    optimiser = Adam(learning_rate=3e-4)
    train(model, dataloader, optimiser)
Hopefully the above is easy to read. AbstractOptimiser defines an interface consisting of init(model) and update(model, grads, state). (We should really add type annotations to their arguments as well.)
Subsequently, the train and make_step functions can be written without needing to know exactly which optimiser has been passed. (We can later implement some other optimiser and use that in the same place.)
The above is very common. Indeed Python has a whole module, abc, for declaring such abc.abstractmethods.
Level 2: intermediate ABCs, abstract attributes, and __init__-only-once
Now let's move on to a natural extension to the above: intermediate ABCs, that introduce partial implementations.
    import abc
    import numpy as np
    import structional

    class AbstractInterpolation(structional.Struct):
        @abc.abstractmethod
        def __call__(self, x: np.ndarray) -> np.ndarray: ...

    class AbstractPolynomialInterpolation(AbstractInterpolation):
        coeffs: structional.AbstractVar[np.ndarray]

        def degree(self) -> int:
            # n coefficients describe a polynomial of degree n - 1
            return len(self.coeffs) - 1

        def __call__(self, x: np.ndarray) -> np.ndarray:
            return np.polyval(self.coeffs, x)

    class CubicInterpolation(AbstractPolynomialInterpolation):
        coeffs: np.ndarray

        def __init__(self, ts: np.ndarray, xs: np.ndarray):
            self.coeffs = ...  # some implementation
In this case, the intermediate ABC AbstractPolynomialInterpolation implements the __call__ method. However, it isn't yet a concrete (non-abstract) class, as it introduces a new abstract variable coeffs; we need to wait until CubicInterpolation for that to be defined.
Using an abstract attribute (structional.AbstractVar) here means that we can write self.coeffs inside degree and __call__ and know that this is safe: unless every abstract attribute has been defined, structional will not allow us to instantiate the class.
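To make the idea of "cannot instantiate until the attribute is defined" concrete, here is a rough standard-library sketch. It is not how structional implements AbstractVar; the CheckedMeta metaclass and the required_attrs convention are hypothetical, invented for this illustration, and the check happens after __init__ has run rather than at class-definition time.

```python
import abc


class CheckedMeta(abc.ABCMeta):
    """Hypothetical metaclass: verify required attributes after construction."""

    def __call__(cls, *args, **kwargs):
        instance = type.__call__(cls, *args, **kwargs)  # runs __new__ and __init__
        missing = [name for name in cls.required_attrs if name not in vars(instance)]
        if missing:
            raise TypeError(f"{cls.__name__} did not set: {', '.join(missing)}")
        return instance


class AbstractPolynomial(metaclass=CheckedMeta):
    required_attrs = ("coeffs",)  # stand-in for an abstract attribute

    def degree(self) -> int:
        return len(self.coeffs) - 1  # safe: every instance is checked to have coeffs


class Cubic(AbstractPolynomial):
    def __init__(self, coeffs):
        self.coeffs = tuple(coeffs)


class Forgetful(AbstractPolynomial):
    def __init__(self):
        pass  # never sets coeffs, so construction fails the check


print(Cubic([1.0, 0.0, 0.0, 0.0]).degree())
```

Calling Forgetful() raises a TypeError naming the missing attribute, which is the guarantee that lets degree use self.coeffs without defensive checks.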
Why didn't we just define AbstractPolynomialInterpolation.coeffs as a concrete field? (Just coeffs: np.ndarray.) Indeed we could have written this:
    class AbstractPolynomialInterpolation(AbstractInterpolation):
        coeffs: np.ndarray

        def __init__(self, coeffs: np.ndarray):
            self.coeffs = coeffs

        def degree(self) -> int:
            # n coefficients describe a polynomial of degree n - 1
            return len(self.coeffs) - 1

        def __call__(self, x: np.ndarray) -> np.ndarray:
            return np.polyval(self.coeffs, x)

    class CubicInterpolation(AbstractPolynomialInterpolation):
        def __init__(self, ts: np.ndarray, xs: np.ndarray):
            coeffs = ...  # some implementation
            super().__init__(coeffs)
but this is now much less readable: we've split up initialisation across two different classes. (What does it even mean to initialise an abstract class (AbstractPolynomialInterpolation) anyway?) This is a reliable source of bugs. Thus we arrive at the rule that the __init__ method must be defined on the final concrete class.
Level 3: implement methods precisely once, and concrete-means-final
Our "__init__ only once" rule means that __init__ is defined precisely once, is never overridden, and we never call super().__init__. Why stop there: perhaps we should enforce that we never override any method?
In practice, we argue that's a good idea! This rule means that when you see code like:
    def foo(interp: AbstractPolynomialInterpolation):
        degree = interp.degree()
        ...
you know that it is calling precisely AbstractPolynomialInterpolation.degree, and not an override in some subclass. This is excellent for code readability. Thus we get the rule that no method should be overridden.
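As an illustration of how a no-override rule could be checked mechanically, here is a small sketch using only the standard library. The check_no_overrides decorator and the shape classes are hypothetical names for this example, and the check is deliberately simple: it only looks at plain methods and only at the nearest base class that defines each name. It is not a description of what structional does internally.

```python
import abc
import types


def check_no_overrides(cls: type) -> type:
    """Raise TypeError if `cls` redefines a method that a base class already implements."""
    for name, value in vars(cls).items():
        if not isinstance(value, types.FunctionType):
            continue  # only plain methods are checked in this sketch
        for base in cls.__mro__[1:-1]:  # skip `cls` itself and `object`
            if name in vars(base):
                if not getattr(vars(base)[name], "__isabstractmethod__", False):
                    raise TypeError(f"{cls.__name__}.{name} overrides {base.__name__}.{name}")
                break  # nearest definition is abstract, so implementing it is fine
    return cls


class AbstractShape(abc.ABC):
    @abc.abstractmethod
    def area(self) -> float: ...

    def describe(self) -> str:
        return f"{type(self).__name__} with area {self.area()}"


@check_no_overrides  # passes: only the abstract method is implemented
class Square(AbstractShape):
    def __init__(self, side: float):
        self.side = side

    def area(self) -> float:
        return self.side**2


class LoudSquare(AbstractShape):
    def area(self) -> float:
        return 1.0

    def describe(self) -> str:  # overrides an already-implemented method
        return "LOUD"


print(Square(2.0).describe())
```

Passing LoudSquare to check_no_overrides raises a TypeError, because describe already has an implementation on AbstractShape.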
If we assume this, then we now find ourselves arriving at a conclusion: concrete means final. That is, once we have a concrete class (every abstract method/attribute defined in our ABCs is now overridden with an implementation, so we can instantiate this class), then it is now final (we're not allowed to re-override things, so subclassing is pointless). This is how we arrive at the important abstract-or-final rule itself.
What about when you have an existing concrete class that you want to tweak just-a-little-bit? In this case, prefer composition over inheritance. Write a wrapper that forwards each method as appropriate. This is just as expressive, and means we keep these readable type-safe rules.
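Here is a minimal sketch of the wrapper approach, again using plain abc and frozen dataclasses instead of structional.Struct. LinearInterpolation and ClippedInterpolation are hypothetical examples: the wrapper holds an existing interpolation, forwards __call__ to it, and tweaks the result, without subclassing the concrete class.

```python
import abc
import dataclasses


class AbstractInterpolation(abc.ABC):
    @abc.abstractmethod
    def __call__(self, x: float) -> float: ...


@dataclasses.dataclass(frozen=True)
class LinearInterpolation(AbstractInterpolation):
    slope: float
    intercept: float

    def __call__(self, x: float) -> float:
        return self.slope * x + self.intercept


@dataclasses.dataclass(frozen=True)
class ClippedInterpolation(AbstractInterpolation):
    """Wraps any interpolation and clamps its output to [lower, upper]."""

    inner: AbstractInterpolation
    lower: float
    upper: float

    def __call__(self, x: float) -> float:
        # Forward to the wrapped object, then apply the tweak.
        return min(max(self.inner(x), self.lower), self.upper)


line = LinearInterpolation(slope=2.0, intercept=1.0)
clipped = ClippedInterpolation(inner=line, lower=0.0, upper=5.0)
print(clipped(1.0), clipped(10.0))
```

Because ClippedInterpolation is itself an AbstractInterpolation, it can be passed anywhere the interface is expected, and it works with any implementation of that interface rather than just one concrete class.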
Level 4: __check_init__
It's pretty common to want to validate that certain invariants hold, even in abstract base classes. For this, we have the __check_init__ method:
    class AbstractPolynomialInterpolation(AbstractInterpolation):
        coeffs: structional.AbstractVar[np.ndarray]

        def __check_init__(self):
            if not np.issubdtype(self.coeffs.dtype, np.floating):
                raise ValueError("Coefficients must be floating-point!")

        ...
This method is something that structional.Struct will look for. After initialisation, the __check_init__ method of every parent class will be run.
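For intuition, the behaviour can be imitated with a dataclass __post_init__ hook that walks the class hierarchy. This is only a sketch of the idea, not structional's implementation: run_check_inits and the interval classes are hypothetical, and running base classes first is a choice made here rather than something the text above specifies.

```python
import dataclasses


def run_check_inits(obj) -> None:
    """Call every `__check_init__` defined anywhere in the class hierarchy of `obj`."""
    for klass in reversed(type(obj).__mro__):  # base classes first (a choice of this sketch)
        check = vars(klass).get("__check_init__")
        if check is not None:
            check(obj)


class AbstractOrdered:
    def __check_init__(self):
        if self.lower > self.upper:
            raise ValueError("lower must not exceed upper")


class AbstractNonNegative:
    def __check_init__(self):
        if self.lower < 0:
            raise ValueError("lower must be non-negative")


@dataclasses.dataclass(frozen=True)
class Interval(AbstractOrdered, AbstractNonNegative):
    lower: float
    upper: float

    def __post_init__(self):
        run_check_inits(self)  # both parents' checks run, with no super() calls needed


print(Interval(0.0, 1.0))
```

Interval(2.0, 1.0) fails the check from AbstractOrdered and Interval(-1.0, 1.0) fails the check from AbstractNonNegative, even though neither check knows about the other.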
Extensions and FAQ
Does this pattern work with multiple inheritance?
Yes. For example, here's a diamond inheritance pattern (for building a differential equation solver):
    # AbstractRootFinder and Newton are assumed to be defined elsewhere.
    class AbstractSolver(structional.Struct):
        @abc.abstractmethod
        def step(self, *args, **kwargs): ...  # arguments omitted for brevity

    class AbstractAdaptiveSolver(AbstractSolver):
        tolerance: structional.AbstractVar[float]

    class AbstractImplicitSolver(AbstractSolver):
        root_finder: structional.AbstractVar[AbstractRootFinder]

    class ImplicitEuler(AbstractAdaptiveSolver, AbstractImplicitSolver):
        tolerance: float
        root_finder: AbstractRootFinder = Newton()

        def step(self, *args, **kwargs):  # arguments omitted for brevity
            ...  # some implementation

    solver = ImplicitEuler(tolerance=1e-3)
That's a lot of Abstracts
Yes.
Does super() ever get used at all?
No. This design pattern means that you should never need to write super() at all.
What about co-operative multiple inheritance?
If you're a Python nerd, you'll now be wondering about co-operative multiple inheritance, which specifies using super() ubiquitously.
The TL;DR of this is that almost no-one ever uses this properly, and the abstract+final pattern is intended as a direct alternative. One sees a lot of code that looks like this:
    class A:
        def __init__(self, x):
            self.x = x
            # Not calling super().__init__, because the superclass is just `object`, right?

    class AA:
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)  # Being a good citizen.
            ...  # Do anything else that needs to happen.

    class B(A, AA):
        pass

    B(x=1)  # AA.__init__ is not called.
In this case B(x=1) calls A.__init__, which never goes on to call AA.__init__. Co-operative multiple inheritance only works if everyone, well, co-operates.
Even if everyone wants to do their best, there is another issue. When writing super().__init__, it isn't actually known what method is being called – as above, super() could be pointing at almost any class at all. This actually means that it's not possible to know what arguments to pass to super().__init__! "Only use keyword arguments" is the closest to a resolution that this issue has, and it's still fragile.
In contrast, our no-overriding and abstract-or-final rules mean that we never come across this scenario. We always know precisely what is being called.
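The point about super() not having a fixed target can be seen in a few lines of plain Python. The class names below are hypothetical; the example simply records which implementations run.

```python
class Root:
    def describe(self) -> list:
        return ["Root"]


class Left(Root):
    def describe(self) -> list:
        # Reads as if this must call Root.describe, but that depends on the instance.
        return ["Left"] + super().describe()


class Right(Root):
    def describe(self) -> list:
        return ["Right"] + super().describe()


class Both(Left, Right):
    pass


print(Left().describe())  # here super() inside Left resolves to Root
print(Both().describe())  # here super() inside Left resolves to Right
print([klass.__name__ for klass in Both.__mro__])
```

The same line of code in Left calls a different class depending on the method resolution order of the instance's type, which is why the author of Left cannot know what arguments the next method expects.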
Hang on, I don't buy the abstract-or-final part. Can't a concrete subclass add a new method to a concrete superclass?
You're thinking of something like this:
    # This is clearly something we can instantiate: it has no abstract methods/attributes.
    class ConcreteArray(structional.Struct):
        def some_method(self):
            pass

    # This is clearly also without abstract methods/attributes, and also doesn't break the
    # rule about overriding methods.
    class ConcreteArrayTwo(ConcreteArray):
        def another_method(self):
            pass
We didn't discuss it above, but we do ban things like this as well. The reason is to simplify things when writing something like:
    def add(x: ConcreteArray, y: ConcreteArray) -> ConcreteArray:
        ...
so that there are never any questions about what the return type should be. (If we passed a ConcreteArrayTwo as x or y, maybe we should try to return a ConcreteArrayTwo instead? What if x = ConcreteArrayTwo() but y = ConcreteArrayThree(), and these two types don't know about each other? Better to avoid the question in the first place.)
These ideas have appeared in <XYZ language>?
Yup! Variants of this design pattern are very common, especially in modern languages like Julia/Rust/etc., or in older languages with a strong emphasis on typing.
Any other advice?
This approach is about going all-in on nominal subtyping, and not structural subtyping.
Correspondingly, don't use hasattr or typing.Protocols. Use isinstance and ABCs instead.
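A small standard-library sketch of the difference, with hypothetical class names: a nominal check with isinstance accepts only classes that explicitly declare themselves to be optimisers, whereas a structural check would also accept an unrelated class that happens to have a method of the same name.

```python
import abc


class AbstractOptimiser(abc.ABC):
    @abc.abstractmethod
    def update(self, params: list, grads: list) -> list: ...


class SGD(AbstractOptimiser):
    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate

    def update(self, params: list, grads: list) -> list:
        return [p - self.learning_rate * g for p, g in zip(params, grads)]


class Database:
    """Unrelated class that merely happens to have an `update` method."""

    def update(self, params: list, grads: list) -> list:
        raise RuntimeError("this is not an optimiser")


def apply_update(optimiser, params: list, grads: list) -> list:
    # Nominal check: the class must have opted in by subclassing the ABC.
    if not isinstance(optimiser, AbstractOptimiser):
        raise TypeError(f"expected an AbstractOptimiser, got {type(optimiser).__name__}")
    return optimiser.update(params, grads)


print(apply_update(SGD(learning_rate=0.5), [1.0], [2.0]))
```

A hasattr(obj, "update") test would let Database through and fail later inside update; the isinstance check rejects it immediately with a clear error.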