# Callbacks: composition over inheritance

Reading:

- [Python Design Patterns: The Composition Over Inheritance Principle](https://python-patterns.guide/gang-of-four/composition-over-inheritance/) by Brandon Rhodes
- [Django Views the Right Way: Helpers vs mixins](https://spookylukey.github.io/django-views-the-right-way/common-context-data.html#discussion-helpers-vs-mixins) by Luke Plant

### Multiple inheritance ("mixins")

In [None]:
class Module:
    """Simplest possible module: ``forward`` returns the square of its input."""

    def forward(self, x):
        return pow(x, 2)

In [None]:
# The plain module just squares its input (10 ** 2).
Module().forward(10)

100

In [None]:
class Module:
    """Module whose ``forward`` exposes before/after hooks.

    ``forward`` stores its input on ``self.x`` and its result on ``self.y``
    so hooks can read or replace them. Both hooks are no-ops here; subclasses
    and mixins may override them (chaining with ``super()``).
    """

    def forward(self, x):
        """Square ``x``, running ``before_forward`` and ``after_forward`` around it.

        Returns:
            ``self.y`` after ``after_forward`` has run, so an after-hook may
            replace the result.
        """
        self.x = x
        self.before_forward()
        # Compute from self.x, not the local x: a before_forward hook that
        # replaces the input (e.g. converting it to a tensor) must affect the
        # result, just as after_forward can replace self.y before it is returned.
        self.y = self.x ** 2
        self.after_forward()
        return self.y

    def before_forward(self):
        """Hook run after ``self.x`` is set and before the computation (no-op)."""

    def after_forward(self):
        """Hook run after ``self.y`` is computed and before it is returned (no-op)."""

In [None]:
# The default hooks do nothing, so the result is the same as the plain module.
Module().forward(10)

100

In [None]:
class LoggingMixin:
    """Mixin that prints the module's input and output around ``forward``.

    Each hook prints first and then defers to the next class in the MRO via
    ``super()``. It therefore has to be mixed into a class (such as ``Module``)
    that also defines ``before_forward`` and ``after_forward``.
    """

    def before_forward(self):
        parent = super()
        print(f'{self.x=}')
        parent.before_forward()

    def after_forward(self):
        parent = super()
        print(f'{self.y=}')
        parent.after_forward()

In [None]:
class MyModule(LoggingMixin, Module):
    """Module whose hooks are extended by LoggingMixin.

    The MRO is MyModule -> LoggingMixin -> Module, so LoggingMixin's hooks run
    and their ``super()`` calls then reach Module's no-op hooks.
    """

MyModule().forward(10)

self.x=10
self.y=100


100

In [None]:
from torch import as_tensor, tensor

class TensorMixin(Module):
    """Mixin that converts the module's input and output to torch tensors.

    Each hook converts first and then defers via ``super()``. Classes later in
    the MRO (e.g. LoggingMixin) therefore see the converted value.

    NOTE(review): unlike LoggingMixin this subclasses Module. It must be listed
    before Module in a subclass's bases, or the MRO cannot be linearized.
    """

    def before_forward(self):
        # as_tensor returns an existing tensor as-is, whereas torch.tensor would
        # copy-construct it and emit a UserWarning (e.g. forward(tensor(10))).
        self.x = as_tensor(self.x)
        super().before_forward()

    def after_forward(self):
        # self.y may already be a tensor (e.g. when computed from a tensor x);
        # as_tensor avoids a redundant copy and warning in that case.
        self.y = as_tensor(self.y)
        super().after_forward()

In [None]:
class MyModule(TensorMixin, LoggingMixin, Module):
    """Module combining both mixins.

    The MRO is TensorMixin -> LoggingMixin -> Module. TensorMixin converts
    each value before delegating via ``super()`` to LoggingMixin, which
    therefore prints the converted tensor.
    """

MyModule().forward(10)

self.x=tensor(10)
self.y=tensor(100)


tensor(100)

### Callbacks

In [None]:
class Module:
    """Module that delegates its hooks to a list of callback objects.

    Each callback receives a ``mod`` attribute pointing back at this module,
    so it can read and replace ``mod.x`` / ``mod.y`` during ``forward``.

    Args:
        cbs: iterable of callback objects. Each may define ``before_forward``
            and/or ``after_forward``; missing hooks are skipped. Defaults to
            no callbacks.
    """

    def __init__(self, cbs=()):
        # Copy into a list so a one-shot iterable (e.g. a generator) is not
        # exhausted by the loop below before forward() iterates it again.
        self.cbs = list(cbs)
        for cb in self.cbs:
            cb.mod = self

    def forward(self, x):
        """Square ``x``, running the callbacks' before/after hooks around it."""
        self.x = x
        self.callback('before_forward')
        # Use self.x so a before_forward callback that replaces the input
        # affects the result, just as after_forward can replace self.y.
        self.y = self.x ** 2
        self.callback('after_forward')
        return self.y

    def callback(self, nm):
        """Call the zero-argument method ``nm`` on every callback, in list order.

        Callbacks that do not define ``nm`` are skipped.
        """
        for cb in self.cbs:
            # The fallback must take no arguments because it is called exactly
            # like a bound hook method; a one-argument lambda raises TypeError.
            getattr(cb, nm, lambda: None)()

In [None]:
class LoggingCB:
    """Callback that prints its module's input and output around ``forward``.

    Reads ``self.mod``, which the owning module assigns when the callback is
    registered.
    """

    def before_forward(self):
        message = f'{self.mod.x=}'
        print(message)

    def after_forward(self):
        message = f'{self.mod.y=}'
        print(message)

In [None]:
# Behaviour is added by passing callback objects in, not by subclassing.
Module([LoggingCB()]).forward(10)

self.mod.x=10
self.mod.y=100


100

In [None]:
class TensorCB:
    """Callback that converts its module's ``x`` and ``y`` to torch tensors.

    Reads and writes ``self.mod``, which the owning module assigns when the
    callback is registered.
    """

    def before_forward(self):
        # as_tensor returns an existing tensor as-is, whereas torch.tensor would
        # copy-construct it and emit a UserWarning (e.g. forward(tensor(10))).
        self.mod.x = as_tensor(self.mod.x)

    def after_forward(self):
        # mod.y may already be a tensor (e.g. when computed from a tensor x).
        self.mod.y = as_tensor(self.mod.y)

In [None]:
# Callbacks run in list order, so TensorCB converts each value before LoggingCB prints it.
Module([TensorCB(), LoggingCB()]).forward(10)

self.mod.x=tensor(10)
self.mod.y=tensor(100)


tensor(100)