{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import matplotlib.pyplot as plt\n", "import random" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Callbacks" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Callbacks as GUI events" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import ipywidgets as widgets" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "From the [ipywidget docs](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20Events.html):\n", "\n", "- *the button widget is used to handle mouse clicks. The on_click method of the Button can be used to register function to be called when the button is clicked*" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "w = widgets.Button(description='Click me')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2a8b2631c38d4f0fa35aeea61c9211ba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Button(description='Click me', style=ButtonStyle())" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "w" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Click handler: o is the single argument on_click passes in (presumably the Button); it is unused here\n", "def f(o): print('hi')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Register f so it runs each time the button is clicked\n", "w.on_click(f)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "*NB: When callbacks are used in this way they are often called \"events\".*" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Creating your own callback" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from time import sleep" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def slow_calculation():\n", "    res = 0\n", "    for i in range(5):\n", "        res += i*i\n", "        sleep(1)  # stand-in for one slow step of real work\n", "    return res" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def slow_calculation(cb=None):\n", "    res = 0\n", "    for i in range(5):\n", "        res += i*i\n", "        sleep(1)\n", "        if cb: cb(i)  # optional callback, called with the step index after each step\n", "    return res" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_progress(epoch): print(f\"Awesome! We've finished epoch {epoch}!\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Awesome! We've finished epoch 0!\n", "Awesome! We've finished epoch 1!\n", "Awesome! We've finished epoch 2!\n", "Awesome! We've finished epoch 3!\n", "Awesome! We've finished epoch 4!\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation(show_progress)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Lambdas and partials" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Awesome! We've finished epoch 0!\n", "Awesome! We've finished epoch 1!\n", "Awesome! We've finished epoch 2!\n", "Awesome! We've finished epoch 3!\n", "Awesome! We've finished epoch 4!\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation(lambda o: print(f\"Awesome! We've finished epoch {o}!\"))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_progress(exclamation, epoch): print(f\"{exclamation}! We've finished epoch {epoch}!\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OK I guess! We've finished epoch 0!\n", "OK I guess! We've finished epoch 1!\n", "OK I guess! We've finished epoch 2!\n", "OK I guess! We've finished epoch 3!\n", "OK I guess! We've finished epoch 4!\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation(lambda o: show_progress(\"OK I guess\", o))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def make_show_progress(exclamation):\n", "    # _inner closes over exclamation, so the returned callback only needs epoch\n", "    def _inner(epoch): print(f\"{exclamation}! We've finished epoch {epoch}!\")\n", "    return _inner" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Nice!! We've finished epoch 0!\n", "Nice!! We've finished epoch 1!\n", "Nice!! We've finished epoch 2!\n", "Nice!! We've finished epoch 3!\n", "Nice!! We've finished epoch 4!\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation(make_show_progress(\"Nice!\"))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OK I guess! We've finished epoch 0!\n", "OK I guess! We've finished epoch 1!\n", "OK I guess! We've finished epoch 2!\n", "OK I guess! We've finished epoch 3!\n", "OK I guess! We've finished epoch 4!\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation(partial(show_progress, \"OK I guess\"))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f2 = partial(show_progress, \"OK I guess\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Callbacks as callable classes" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ProgressShowingCallback():\n", "    # __call__ makes instances callable, so one can be passed wherever a function callback is expected\n", "    def __init__(self, exclamation=\"Awesome\"): self.exclamation = exclamation\n", "    def __call__(self, epoch): print(f\"{self.exclamation}! We've finished epoch {epoch}!\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb = ProgressShowingCallback(\"Just super\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Just super! We've finished epoch 0!\n", "Just super! We've finished epoch 1!\n", "Just super! We've finished epoch 2!\n", "Just super! We've finished epoch 3!\n", "Just super! We've finished epoch 4!\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation(cb)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Multiple callback funcs; `*args` and `**kwargs`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def f(*a, **b): print(f\"args: {a}; kwargs: {b}\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "args: (3, 'a'); kwargs: {'thing1': 'hello'}\n" ] } ], "source": [ "f(3, 'a', thing1=\"hello\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def g(a,b,c=0): print(a,b,c)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 2 3\n" ] } ], "source": [ "args = [1,2]\n", "kwargs = {'c':3}\n", "g(*args, **kwargs)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def slow_calculation(cb=None):\n", "    # cb (optional) must provide both before_calc and after_calc\n", "    res = 0\n", "    for i in range(5):\n", "        if cb: cb.before_calc(i)\n", "        res += i*i\n", "        sleep(1)\n", "        if cb: cb.after_calc(i, val=res)\n", "    return res" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PrintStepCallback():\n", "    # *args/**kwargs let these hooks accept whatever slow_calculation passes and ignore it\n", "    def before_calc(self, *args, **kwargs): print(\"About to start\")\n", "    def after_calc (self, *args, **kwargs): print(\"Done step\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "About to start\n", "Done step\n", "About to start\n", "Done step\n", "About to start\n", "Done step\n", "About to start\n", "Done step\n", "About to start\n", "Done step\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type":
"execute_result" } ], "source": [ "slow_calculation(PrintStepCallback())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PrintStatusCallback():\n", "    # named parameters pick out the values needed; **kwargs absorbs anything extra\n", "    def before_calc(self, epoch, **kwargs): print(f\"About to start: {epoch}\")\n", "    def after_calc (self, epoch, val, **kwargs): print(f\"After {epoch}: {val}\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "About to start: 0\n", "After 0: 0\n", "About to start: 1\n", "After 1: 1\n", "About to start: 2\n", "After 2: 5\n", "About to start: 3\n", "After 3: 14\n", "About to start: 4\n", "After 4: 30\n" ] }, { "data": { "text/plain": [ "30" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slow_calculation(PrintStatusCallback())" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Modifying behavior" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def slow_calculation(cb=None):\n", "    res = 0\n", "    for i in range(5):\n", "        # both hooks are optional; a truthy return from after_calc stops the loop early\n", "        if cb and hasattr(cb,'before_calc'): cb.before_calc(i)\n", "        res += i*i\n", "        sleep(1)\n", "        if cb and hasattr(cb,'after_calc'):\n", "            if cb.after_calc(i, res):\n", "                print(\"stopping early\")\n", "                break\n", "    return res" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PrintAfterCallback():\n", "    def after_calc (self, epoch, val):\n", "        print(f\"After {epoch}: {val}\")\n", "        if val>10: return True  # returning True tells slow_calculation to stop early" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "After 0: 0\n", "After 1: 1\n", "After 2: 5\n", "After 3: 14\n", "stopping early\n" ] }, { "data": { "text/plain": [ "14" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
"slow_calculation(PrintAfterCallback())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SlowCalculator():\n", "    def __init__(self, cb=None): self.cb,self.res = cb,0\n", "\n", "    def callback(self, cb_name, *args):\n", "        # hooks are optional: skip when there is no callback object or it lacks cb_name\n", "        if not self.cb: return\n", "        cb = getattr(self.cb,cb_name, None)\n", "        # the calculator passes itself, so the hook can read and modify its state\n", "        if cb: return cb(self, *args)\n", "\n", "    def calc(self):\n", "        self.res = 0  # start from zero each run, so calling calc() again gives the same result\n", "        for i in range(5):\n", "            self.callback('before_calc', i)\n", "            self.res += i*i\n", "            sleep(1)\n", "            if self.callback('after_calc', i):\n", "                print(\"stopping early\")\n", "                break" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ModifyingCallback():\n", "    # returning True stops calc() early; results below 3 are doubled in place on the calculator\n", "    def after_calc (self, calc, epoch):\n", "        print(f\"After {epoch}: {calc.res}\")\n", "        if calc.res>10: return True\n", "        if calc.res<3: calc.res = calc.res*2" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "calculator = SlowCalculator(ModifyingCallback())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "After 0: 0\n", "After 1: 1\n", "After 2: 6\n", "After 3: 15\n", "stopping early\n" ] }, { "data": { "text/plain": [ "15" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "calculator.calc()\n", "calculator.res" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## `__dunder__` thingies" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Anything that looks like `__this__` is, in some way, *special*. Python, or some library, can define some functions that they will call at certain documented times. For instance, when your class is setting up a new object, Python will call `__init__`. These are defined as part of the Python [data model](https://docs.python.org/3/reference/datamodel.html#object.__init__).\n", "\n", "For instance, if Python sees `+`, then it will call the special method `__add__`. If you try to display an object in Jupyter (or lots of other places in Python) it will call `__repr__`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SloppyAdder():\n", "    def __init__(self,o): self.o=o\n", "    # deliberately sloppy: every addition adds an extra 0.01 to the result\n", "    def __add__(self,b): return SloppyAdder(self.o + b.o + 0.01)\n", "    def __repr__(self): return str(self.o)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.01" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = SloppyAdder(1)\n", "b = SloppyAdder(2)\n", "a+b" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Special methods you should probably know about (see data model link above) are:\n", "\n", "- `__getitem__`\n", "- `__getattr__`\n", "- `__setattr__`\n", "- `__del__`\n", "- `__init__`\n", "- `__new__`\n", "- `__enter__`\n", "- `__exit__`\n", "- `__len__`\n", "- `__repr__`\n", "- `__str__`" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### `__getattr__` and `getattr`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A: a,b=1,2" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = A()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a.b" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "getattr(a, 'b')" ] },
{ "cell_type": "code",
"execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random.seed(42)  # seed first so this random choice (and the output below) is reproducible\n", "getattr(a, 'b' if random.random()>0.5 else 'a')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class B:\n", "    a,b=1,2\n", "    def __getattr__(self, k):\n", "        # only reached when normal lookup fails (b.a below still returns 1)\n", "        # startswith also copes with k == '', where k[0] would raise IndexError instead of AttributeError\n", "        if k.startswith('_'): raise AttributeError(k)\n", "        return f'Hello from {k}'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "b = B()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.a" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Hello from foo'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.foo" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }