{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "42382be0",
   "metadata": {},
   "source": [
    "# Callbacks: composition over inheritance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe1fa5ac",
   "metadata": {},
   "source": [
    "Reading:\n",
    "\n",
    "- [Python Design Patterns: The Composition Over Inheritance Principle](https://python-patterns.guide/gang-of-four/composition-over-inheritance/) by Brandon Rhodes\n",
    "- [Django Views the Right Way: Helpers vs mixins](https://spookylukey.github.io/django-views-the-right-way/common-context-data.html?highlight=mixin#discussion-helpers-vs-mixins) by Luke Plant"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30ce2cf3",
   "metadata": {},
   "source": [
    "### Multiple inheritance (\"mixins\")\n",
    "\n",
    "Each mixin overrides the hook methods and chains to the next class in the MRO via `super()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd704487",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Module:\n",
    "    \"\"\"Minimal module: `forward` returns the square of its input.\"\"\"\n",
    "    def forward(self, x): return x ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aaa3405",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Module().forward(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "182a19e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Module:\n",
    "    \"\"\"Module whose `forward` calls overridable hooks around the computation.\"\"\"\n",
    "    def forward(self, x):\n",
    "        self.x = x\n",
    "        self.before_forward()\n",
    "        # Computed from the argument `x`, not from `self.x`.\n",
    "        self.y = x ** 2\n",
    "        self.after_forward()\n",
    "        return self.y\n",
    "\n",
    "    # No-op hooks; subclasses and mixins override these.\n",
    "    def before_forward(self): pass\n",
    "    def after_forward(self): pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ee2dabb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Module().forward(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08d10922",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoggingMixin:\n",
    "    \"\"\"Mixin that prints `self.x` before and `self.y` after the forward computation.\n",
    "\n",
    "    It defines no base hooks of its own, so it must be listed before a class\n",
    "    that provides `before_forward` / `after_forward` for `super()` to reach.\n",
    "    \"\"\"\n",
    "    def before_forward(self):\n",
    "        print(f'{self.x=}')\n",
    "        super().before_forward()\n",
    "\n",
    "    def after_forward(self):\n",
    "        print(f'{self.y=}')\n",
    "        super().after_forward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0146c7d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.x=10\n",
      "self.y=100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyModule(LoggingMixin, Module): pass\n",
    "MyModule().forward(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de7c3626",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import tensor\n",
    "\n",
    "class TensorMixin(Module):\n",
    "    \"\"\"Mixin that converts `self.x` and `self.y` to tensors inside the hooks.\n",
    "\n",
    "    Unlike `LoggingMixin`, this class subclasses `Module` directly.\n",
    "    \"\"\"\n",
    "    def before_forward(self):\n",
    "        self.x = tensor(self.x)\n",
    "        super().before_forward()\n",
    "\n",
    "    def after_forward(self):\n",
    "        self.y = tensor(self.y)\n",
    "        super().after_forward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e028850",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.x=tensor(10)\n",
      "self.y=tensor(100)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(100)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# TensorMixin comes first in the MRO, so its hooks run before LoggingMixin prints.\n",
    "class MyModule(TensorMixin, LoggingMixin, Module): pass\n",
    "MyModule().forward(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61366f1c",
   "metadata": {},
   "source": [
    "### Callbacks\n",
    "\n",
    "Here the module holds a list of callback objects and calls whichever hook methods they define."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7be468a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Module:\n",
    "    \"\"\"Module that dispatches its hooks to a collection of callback objects.\n",
    "\n",
    "    `cbs` is an iterable of callbacks (default: none). Each callback receives\n",
    "    a `mod` attribute pointing back at this module.\n",
    "    \"\"\"\n",
    "    def __init__(self, cbs=()):\n",
    "        self.cbs = cbs\n",
    "        for cb in cbs: cb.mod = self\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.x = x\n",
    "        self.callback('before_forward')\n",
    "        # Computed from the argument `x`, not from `self.x`.\n",
    "        self.y = x ** 2\n",
    "        self.callback('after_forward')\n",
    "        return self.y\n",
    "\n",
    "    def callback(self, nm):\n",
    "        \"\"\"Call the method named `nm` on each callback, skipping callbacks that lack it.\"\"\"\n",
    "        # The fallback must take zero arguments because hooks are invoked with none.\n",
    "        for cb in self.cbs: getattr(cb, nm, lambda: None)()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ff00775",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoggingCB:\n",
    "    \"\"\"Callback that prints the module's `x` before and `y` after the computation.\"\"\"\n",
    "    def before_forward(self): print(f'{self.mod.x=}')\n",
    "    def after_forward(self): print(f'{self.mod.y=}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "484f1530",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.mod.x=10\n",
      "self.mod.y=100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Module([LoggingCB()]).forward(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "246b5686",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TensorCB:\n",
    "    \"\"\"Callback that converts the module's `x` and `y` to tensors.\"\"\"\n",
    "    def before_forward(self): self.mod.x = tensor(self.mod.x)\n",
    "    def after_forward(self): self.mod.y = tensor(self.mod.y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9935e6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.mod.x=tensor(10)\n",
      "self.mod.y=tensor(100)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(100)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Callbacks run in list order, so TensorCB converts before LoggingCB prints.\n",
    "Module([TensorCB(), LoggingCB()]).forward(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064b4254",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}