{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "#default_exp dispatch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#export\n", "from fastcore.imports import *\n", "from fastcore.foundation import *\n", "from fastcore.utils import *\n", "\n", "from collections import defaultdict" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *\n", "from fastcore.test import *\n", "from fastcore.nb_imports import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Type dispatch\n", "\n", "> Basic single and dual parameter dispatch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helpers" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#export\n", "def lenient_issubclass(cls, types):\n", " \"If possible return whether `cls` is a subclass of `types`, otherwise return False.\"\n", " if cls is object and types is not object: return False # treat `object` as highest level\n", " try: return isinstance(cls, types) or issubclass(cls, types)\n", " except: return False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert not lenient_issubclass(typing.Collection, list)\n", "assert lenient_issubclass(list, typing.Collection)\n", "assert lenient_issubclass(typing.Collection, object)\n", "assert lenient_issubclass(typing.List, typing.Collection)\n", "assert not lenient_issubclass(typing.Collection, typing.List)\n", "assert not lenient_issubclass(object, typing.Callable)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#export\n", "def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):\n", " \"Return a new list 
containing all items from the iterable sorted topologically\"\n", " l,res = L(list(iterable)),[]\n", " for _ in range(len(l)):\n", " t = l.reduce(lambda x,y: y if cmp(y,x) else x)\n", " res.append(t), l.remove(t)\n", " return res[::-1] if reverse else res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "td = [3, 1, 2, 5]\n", "test_eq(sorted_topologically(td), [1, 2, 3, 5])\n", "test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "td = {int:1, numbers.Number:2, numbers.Integral:3}\n", "test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "td = [numbers.Integral, tuple, list, int, dict]\n", "td = sorted_topologically(td, cmp=lenient_issubclass)\n", "assert td.index(int) < td.index(numbers.Integral)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#export\n", "def _chk_defaults(f, ann):\n", " pass\n", "# Implementation removed until we can figure out how to do this without `inspect` module\n", "# try: # Some callables don't have signatures, so ignore those errors\n", "# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]\n", "# if any(p.default!=inspect.Parameter.empty for p in params):\n", "# warn(f\"{f.__name__} has default params. 
These will be ignored.\")\n", "# except ValueError: pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#export\n", "def _p2_anno(f):\n", " \"Get the 1st 2 annotations of `f`, defaulting to `object`\"\n", " hints = type_hints(f)\n", " ann = [o for n,o in hints.items() if n!='return']\n", " if callable(f): _chk_defaults(f, ann)\n", " while len(ann)<2: ann.append(object)\n", " return ann[:2]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "#hide\n", "def _f(a): pass\n", "test_eq(_p2_anno(_f), (object,object))\n", "def _f(a, b): pass\n", "test_eq(_p2_anno(_f), (object,object))\n", "def _f(a:None, b)->str: pass\n", "test_eq(_p2_anno(_f), (NoneType,object))\n", "def _f(a:str, b)->float: pass\n", "test_eq(_p2_anno(_f), (str,object))\n", "def _f(a:None, b:str)->float: pass\n", "test_eq(_p2_anno(_f), (NoneType,str))\n", "def _f(a:int, b:int)->float: pass\n", "test_eq(_p2_anno(_f), (int,int))\n", "def _f(self, a:int, b:int): pass\n", "test_eq(_p2_anno(_f), (int,int))\n", "def _f(a:int, b:str)->float: pass\n", "test_eq(_p2_anno(_f), (int,str))\n", "test_eq(_p2_anno(attrgetter('foo')), (object,object))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [ { "data": { "text/plain": [ "([object, object], [int, object])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#hide\n", "# Disabled until _chk_defaults fixed\n", "# def _f(x:int,y:int=10): pass\n", "# test_warns(lambda: _p2_anno(_f))\n", "def _f(x:int,y=10): pass\n", "_p2_anno(None),_p2_anno(_f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TypeDispatch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Type dispatch, or [Multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function 
behaves based upon the input types it receives. This is a prominent feature in some programming languages, such as Julia. For example, here is a [conceptual example](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia) of how multiple dispatch works in Julia, running different code depending on the types of `x` and `y`:\n", "\n", "```julia\n", "collide_with(x::Asteroid, y::Asteroid) = ... \n", "# deal with asteroid hitting asteroid\n", "\n", "collide_with(x::Asteroid, y::Spaceship) = ... \n", "# deal with asteroid hitting spaceship\n", "\n", "collide_with(x::Spaceship, y::Asteroid) = ... \n", "# deal with spaceship hitting asteroid\n", "\n", "collide_with(x::Spaceship, y::Spaceship) = ... \n", "# deal with spaceship hitting spaceship\n", "```\n", "\n", "Type dispatch can be especially useful in data science, where you might want a function that processes data to accept different input types (e.g. NumPy arrays and pandas DataFrames). Type dispatch lets you expose a common API for functions that do similar tasks on different types.\n", "\n", "The `TypeDispatch` class implements type dispatch in Python. It maintains a dictionary that maps the types in each function's annotations to that function, so that the appropriate function is chosen based on the types of the arguments passed in."
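, "\n", "\n", "To make the idea concrete, here is a minimal single-argument sketch in plain Python. It is a hypothetical illustration rather than fastcore code: `dispatch_table`, `register` and `dispatch` are names invented for this example. It keeps a dictionary from types to functions and, on each call, picks the most specific registered type that is a super-class of the argument's type. Because it uses `issubclass`, abstract base classes such as `numbers.Number` match as well. If nothing matches, the input is returned unchanged, as `TypeDispatch.__call__` does.

```python
import numbers

# Hypothetical single-argument dispatcher (not part of fastcore): maps types to functions
dispatch_table = {}

def register(t):
    'Decorator that registers the decorated function for type `t`.'
    def deco(f):
        dispatch_table[t] = f
        return f
    return deco

def dispatch(x):
    'Call the function registered for the most specific super-class of `type(x)`.'
    matches = [t for t in dispatch_table if issubclass(type(x), t)]
    if not matches: return x  # no match: return the input unchanged
    # Most specific match: no other match is a strict subclass of it
    # (ties between unrelated types are broken by registration order)
    best = next(t for t in matches
                if not any(o is not t and issubclass(o, t) for o in matches))
    return dispatch_table[best](x)

@register(numbers.Number)
def describe_number(x): return 'number'

@register(int)
def describe_int(x): return 'int'

@register(bool)
def describe_bool(x): return 'bool'

print(dispatch(3.5))   # float only matches numbers.Number
print(dispatch(3))     # int is more specific than numbers.Number
print(dispatch(True))  # bool is more specific than int
print(dispatch('hi'))  # str matches nothing, so 'hi' is returned as-is
```

Unlike this sketch, `TypeDispatch` considers the annotations of the first two arguments, keeps its table sorted topologically with `lenient_issubclass` instead of searching for the most specific match on every call, and caches its lookups."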
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#export\n", "class _TypeDict:\n", " def __init__(self): self.d,self.cache = {},{}\n", "\n", " def _reset(self):\n", " self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}\n", " self.cache = {}\n", "\n", " def add(self, t, f):\n", " \"Add type `t` and function `f`\"\n", " if not isinstance(t,tuple): t=tuple(L(t))\n", " for t_ in t: self.d[t_] = f\n", " self._reset()\n", "\n", " def all_matches(self, k):\n", " \"Find the functions for all matching types that are super-classes of `k`\"\n", " if k not in self.cache:\n", " types = [f for f in self.d if lenient_issubclass(k,f)]\n", " self.cache[k] = [self.d[o] for o in types]\n", " return self.cache[k]\n", "\n", " def __getitem__(self, k):\n", " \"Find the function for the first matching type that is a super-class of `k`\"\n", " res = self.all_matches(k)\n", " return res[0] if len(res) else None\n", "\n", " def __repr__(self): return self.d.__repr__()\n", " def first(self): return first(self.d.values())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#export\n", "class TypeDispatch:\n", " \"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`\"\n", " def __init__(self, funcs=(), bases=()):\n", " self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))\n", " for o in L(funcs): self.add(o)\n", " self.inst = None\n", " self.owner = None\n", "\n", " def add(self, f):\n", " \"Add function `f`, keyed by the types of its first two annotations\"\n", " if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)\n", " else: a0,a1 = _p2_anno(f)\n", " t = self.funcs.d.get(a0)\n", " if t is None:\n", " t = _TypeDict()\n", " self.funcs.add(a0, t)\n", " t.add(a1, f)\n", "\n", " def first(self):\n", " \"Get first function in ordered dict of type:func.\"\n", " return self.funcs.first().first()\n", " \n", " def returns(self, x):\n", " \"Get the 
return-type annotation of the function that `x` dispatches to.\"\n", " return anno_ret(self[type(x)])\n", "\n", " def _attname(self,k): return getattr(k,'__name__',str(k))\n", " def __repr__(self):\n", " r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, \"__name__\", type(v).__name__)}'\n", " for k in self.funcs.d for l,v in self.funcs[k].d.items()]\n", " r = r + [o.__repr__() for o in self.bases]\n", " return '\\n'.join(r)\n", "\n", " def __call__(self, *args, **kwargs):\n", " ts = L(args).map(type)[:2]\n", " f = self[tuple(ts)]\n", " if not f: return args[0]\n", " if isinstance(f, staticmethod): f = f.__func__\n", " elif self.inst is not None: f = MethodType(f, self.inst)\n", " elif self.owner is not None: f = MethodType(f, self.owner)\n", " return f(*args, **kwargs)\n", "\n", " def __get__(self, inst, owner):\n", " self.inst = inst\n", " self.owner = owner\n", " return self\n", "\n", " def __getitem__(self, k):\n", " \"Find the first function whose annotated types are super-classes of the types in `k`\"\n", " k = L(k)\n", " while len(k)<2: k.append(object)\n", " r = self.funcs.all_matches(k[0])\n", " for t in r:\n", " o = t[k[1]]\n", " if o is not None: return o\n", " for base in self.bases:\n", " res = base[k]\n", " if res is not None: return res\n", " return None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To demonstrate how `TypeDispatch` works, we define a set of functions that accept a variety of input types, specified with different type annotations:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def f2(x:int, y:float): return x+y # int for 1st arg, float for 2nd arg\n", "def f_nin(x:numbers.Integral)->int: return x+1 # integral numbers\n", "def f_ni2(x:int): return x # int\n", "def f_bll(x:(bool,list)): return x # bool or list\n", "def f_num(x:numbers.Number): return x # any number (root of the numeric tower)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can 
optionally initialize `TypeDispatch` with a list of functions we want to 
search. Printing an instance of `TypeDispatch` displays a convenient mapping of types to functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(bool,object) -> f_bll\n", "(int,object) -> f_ni2\n", "(Integral,object) -> f_nin\n", "(Number,object) -> f_num\n", "(list,object) -> f_bll\n", "(object,object) -> NoneType" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n", "t" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `TypeDispatch` only uses the annotations of the first two arguments. If a function has only one argument, the type of the second is shown as `object`. If you pass `None` to `TypeDispatch`, it is displayed as `(object,object) -> NoneType`.\n", "\n", "`TypeDispatch` is a dictionary-like object, so you can retrieve a function by indexing it with a type. For example, the expression:\n", "\n", "```py\n", "t[float]\n", "```\n", "returns `f_num`, because `f_num` is annotated with `numbers.Number`, which is a super-class of `float`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert issubclass(float, numbers.Number)\n", "test_eq(t[float], f_num)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same is true for other types as well:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(t[np.int32], f_nin)\n", "test_eq(t[bool], f_bll)\n", "test_eq(t[list], f_bll)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you look up a type that doesn't match, `TypeDispatch` returns `None`:" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/markdown": [ "
#### `TypeDispatch.add`\n", "\n", "> `TypeDispatch.add`(**`f`**)\n",
"\n",
"Add function `f`, keyed by the types of its first two annotations"
],
"text/plain": [
"#### `TypeDispatch.__call__`\n", "\n", "> `TypeDispatch.__call__`(**\\*`args`**, **\\*\\*`kwargs`**)\n",
"\n",
"Call self as a function."
],
"text/plain": [
"#### `TypeDispatch.returns`\n", "\n", "> `TypeDispatch.returns`(**`x`**)\n",
"\n",
"Get the return-type annotation of the function that `x` dispatches to."
],
"text/plain": [
"