{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp core.transform" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.core.imports import *\n", "from local.core.foundation import *\n", "from local.core.utils import *\n", "from local.core.dispatch import *\n", "from local.test import *\n", "from local.notebook.showdoc import show_doc" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "import torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Transforms\n", "\n", "> Definition of `Transform` and `Pipeline`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The classes here provide functionality for creating a composition of *partially reversible functions*. By \"partially reversible\" we mean that a transform can be `decode`d, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).\n", "\n", "Classes are also provided for composing transforms and mapping them over collections. `Pipeline` is a class that composes several `Transform`s, knowing how to decode them or show an encoded item." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Types" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`ArrayImage`, `ArrayImageBW` and `ArrayMask` are subclasses of `ndarray` that know how to show themselves." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ArrayBase(ndarray):\n", "    \"`ndarray` subclass built from `x`: a shape tuple allocates a new array of that shape, anything else is converted with `array` (if needed) and viewed as `cls`\"\n", "    def __new__(cls, x, *args, **kwargs):\n", "        # A tuple is a shape (plus optional `ndarray` constructor args): return the freshly allocated array directly\n", "        if isinstance(x,tuple): return super().__new__(cls, x, *args, **kwargs)\n", "        if args or kwargs: raise RuntimeError('Unknown array init args')\n", "        if not isinstance(x,ndarray): x = array(x)\n", "        return x.view(cls)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ArrayImageBase(ArrayBase):\n", "    \"Base for arrays that display themselves with `show_image`, using `_show_args` as default keyword arguments\"\n", "    _show_args = {'cmap':'viridis'}\n", "    def show(self, ctx=None, **kwargs):\n", "        return show_image(self, ctx=ctx, **{**self._show_args, **kwargs})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ArrayImage(ArrayImageBase): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ArrayImageBW(ArrayImage): _show_args = {'cmap':'Greys'}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ArrayMask(ArrayImageBase): _show_args = {'alpha':0.5, 'cmap':'tab20'}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im = Image.open(TEST_IMAGE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im_t = ArrayImage(im)\n", "test_eq(type(im_t), ArrayImage)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "im_t2 = ArrayMask(1)\n", "test_eq(type(im_t2), ArrayMask)\n", "test_eq(im_t2, array(1))\n", "\n", "# A tuple is a shape, not data\n", "im_t3 = ArrayMask((2,3))\n", "test_eq(type(im_t3), ArrayMask)\n", "test_eq(im_t3.shape, (2,3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ax = im_t.show(figsize=(2,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_fig_exists(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transform -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_tfm_methods = 'encodes','decodes','setups'\n", "\n", "class _TfmDict(dict):\n", " def __setitem__(self,k,v):\n", " if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n", " if k not in self: super().__setitem__(k,TypeDispatch())\n", " self[k].add(v)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class _TfmMeta(type):\n", " def __new__(cls, name, bases, dict):\n", " res = super().__new__(cls, name, bases, dict)\n", " res.__signature__ = inspect.signature(res.__init__)\n", " return res\n", "\n", " def __call__(cls, *args, **kwargs):\n", " f = args[0] if args else None\n", " n = getattr(f,'__name__',None)\n", " for nm in _tfm_methods:\n", " if not hasattr(cls,nm): setattr(cls, nm, TypeDispatch())\n", " if callable(f) and n in _tfm_methods:\n", " getattr(cls,n).add(f)\n", " return f\n", " return super().__call__(*args, **kwargs)\n", "\n", " @classmethod\n", " def __prepare__(cls, name, bases): return _TfmDict()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Transform(metaclass=_TfmMeta):\n", " \"Delegates (`__call__`,`decode`,`setup`) to (`encodes`,`decodes`,`setups`) if `split_idx` matches\"\n", " split_idx,init_enc,as_item_force,as_item,order = None,False,None,True,0\n", " def __init__(self, enc=None, dec=None, split_idx=None, as_item=False, order=None):\n", " self.split_idx,self.as_item = ifnone(split_idx, self.split_idx),as_item\n", " if order is not None: self.order=order\n", " 
self.init_enc = enc or dec\n", " if not self.init_enc: return\n", "\n", " # Passing enc/dec, so need to remove (base) class level enc/dec\n", " del(self.__class__.encodes,self.__class__.decodes,self.__class__.setups)\n", " self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()\n", " if enc:\n", " self.encodes.add(enc)\n", " self.order = getattr(enc,'order',self.order)\n", " if dec: self.decodes.add(dec)\n", "\n", " @property\n", " def use_as_item(self): return ifnone(self.as_item_force, self.as_item)\n", " def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)\n", " def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs)\n", " def setup(self, items=None): return self.setups(items)\n", " def __repr__(self): return f'{self.__class__.__name__}: {self.use_as_item} {self.encodes} {self.decodes}'\n", "\n", " def _call(self, fn, x, split_idx=None, **kwargs):\n", " if split_idx!=self.split_idx and self.split_idx is not None: return x\n", " f = getattr(self, fn)\n", " if self.use_as_item or not is_listy(x): return self._do_call(f, x, **kwargs)\n", " res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)\n", " return retain_type(res, x)\n", "\n", " def _do_call(self, f, x, **kwargs):\n", " return x if f is None else retain_type(f(x, **kwargs), x, f.returns_none(x))\n", "\n", "add_docs(Transform, decode=\"Delegate to `decodes` to undo transform\", setup=\"Delegate to `setups` to set up transform\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class Transform[source]

\n", "\n", "> Transform(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`as_item`**=*`False`*)\n", "\n", "Delegates (`__call__`,`decode`,`setup`) to ([`encodes`](/tabular.rapids.html#encodes),`decodes`,[`setups`](/tabular.rapids.html#setups)) if `split_idx` matches" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `Transform` is the main building block of the fastai data pipelines. In the most general terms a transform can be any function you want to apply to your data; however, the `Transform` class provides several mechanisms that make the process of building them easy and flexible.\n", "\n", "### The main `Transform` features:\n", "\n", "- **Type dispatch** - Type annotations are used to determine if a transform should be applied to the given argument. It also gives an option to provide several implementations and it chooses the one to run based on the type. This is useful, for example, when running both independent and dependent variables through the pipeline where some transforms only make sense for one and not the other. Another use case is designing a transform that handles different data formats. Note that if a transform takes multiple arguments only the type of the first one is used for dispatch. \n", "- **Handling of tuples** - When a tuple (or another collection satisfying `is_listy`) of data is passed to a transform it will get applied to each element separately. Most commonly it will be a *(x,y)* tuple, but it can be anything, for example a list of images. You can opt out of this behavior by setting the flag `as_item=True`. For transforms that must always operate on the tuple level you can set `as_item_force=True`, which takes precedence over `as_item`; an example of that is `PointScaler`.\n", "- **Reversibility** - A transform can be made reversible by implementing the `decodes` method. 
This is mainly used to turn something like a category which is encoded as a number back into a label understandable by humans for showing purposes.\n", "- **Type propagation** - Whenever possible a transform tries to return data of the same type it received. This is mainly used to maintain the semantics of types like `ArrayImage`, which is a thin wrapper around numpy's `ndarray`. You can opt out of this behavior by adding a `->None` return type annotation.\n", "- **Preprocessing** - The `setup` method can be used to perform any one-time calculations to be later used by the transform, for example generating a vocabulary to encode categorical data.\n", "- **Filtering based on the dataset type** - By setting the `split_idx` flag you can make the transform be used only in a specific `DataSource` subset, like in training but not validation.\n", "- **Ordering** - You can set the `order` attribute, which the `Pipeline` uses to sort its transforms.\n", "- **Appending new behavior with decorators** - You can easily extend an existing `Transform` by creating `encodes` or `decodes` methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you wish them patched into. 
This can be used by fastai library users to add their own behavior, or by multiple modules contributing to the same transform.\n", "\n", "### Defining a `Transform`\n", "There are a few ways to create a transform with different ratios of simplicity to flexibility.\n", "- **Extending the `Transform` class** - Use inheritance to implement the methods you want.\n", "- **Passing methods to the constructor** - Instantiate the `Transform` class and pass your functions as `enc` and `dec` arguments.\n", "- **@Transform decorator** - Turn any function into a `Transform` by just adding a decorator - very straightforward if all you need is a single `encodes` implementation.\n", "- **Passing a function to fastai APIs** - Same as above, but when passing a function to other transform-aware classes like `Pipeline` or `TfmdDS` you don't even need a decorator. Your function will get converted to a `Transform` automatically." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): pass\n", "@A\n", "def encodes(self, x): return x+1\n", "f1 = A()\n", "test_eq(f1(1), 2)\n", "\n", "class B(A): pass\n", "f2 = B()\n", "test_eq(f2(1), 2)\n", "\n", "class A(Transform): pass\n", "f3 = A()\n", "test_eq_type(f3(2), 2)\n", "test_eq_type(f3.decode(2.0), 2.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Transform` can wrap a plain function passed to its constructor, or be used as a decorator, to turn that function into a `Transform`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f = Transform(lambda o:o//2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq_type(f(2), 1)\n", "test_eq_type(f.decode(2.0), 2.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@Transform\n", "def f(x): return x//2\n", "test_eq_type(f(2), 1)\n", "test_eq_type(f.decode(2.0), 2.0)\n", "\n", "@Transform\n", "def f(x): return x*2\n", "test_eq_type(f(2), 4)\n", "test_eq_type(f.decode(2.0), 2.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can derive from `Transform` and use `encodes` for your encoding function." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "A: False (ArrayImage,object) -> encodes (ArrayImage,object) -> decodes" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class A(Transform):\n", "    def encodes(self, x:ArrayImage): return -x\n", "    def decodes(self, x:ArrayImage): return x+1\n", "    def setups (self, x:ArrayImage): x.foo = 'a'\n", "f = A()\n", "t = f(im_t)\n", "test_eq(t, -im_t)\n", "test_eq(f(1), 1)\n", "test_eq(type(t), ArrayImage)\n", "test_eq(f.decode(t), -im_t+1)\n", "test_eq(f.decode(1), 1)\n", "f.setup(im_t)\n", "test_eq(im_t.foo, 'a')\n", "# `setups` is only defined for `ArrayImage`, so setting up a plain `ndarray` must not add `foo` to it\n", "t2 = array(1)\n", "f.setup(t2)\n", "assert not hasattr(t2,'foo')\n", "f" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without return annotation we get an `Int` back since that's what was passed." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): pass\n", "@A\n", "def encodes(self, x:Int): return x//2\n", "@A\n", "def encodes(self, x:float): return x+1\n", "\n", "f = A()\n", "test_eq_type(f(Int(2)), Int(1))\n", "test_eq_type(f(2), 2)\n", "test_eq_type(f(2.), 3.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without return annotation we don't cast if we're not a subclass of the input type. If the annotation is a tuple, then any type in the tuple will match." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform):\n", " def encodes(self, x:(Int,float)): return x/2\n", " def encodes(self, x:(str,list)): return str(x)+'1'\n", "\n", "f = A()\n", "test_eq_type(f(Int(2)), 1.)\n", "test_eq_type(f(2), 2)\n", "test_eq_type(f(Float(2.)), Float(1.))\n", "test_eq_type(f('a'), 'a1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With return annotation `None` we get back whatever Python creates usually." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def func(x)->None: return x/2\n", "f = Transform(func)\n", "test_eq_type(f(2), 1.)\n", "test_eq_type(f(2.), 1.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since `decodes` has no return annotation, but `encodes` created an `Int` and we pass that result here to `decode`, we end up with an `Int`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def func(x): return Int(x+1)\n", "def dec (x): return x-1\n", "f = Transform(func,dec)\n", "t = f(1)\n", "test_eq_type(t, Int(2))\n", "test_eq_type(f.decode(t), Int(1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the transform has `split_idx` then it's only applied if `split_idx` param matches." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f.split_idx = 1\n", "test_eq(f(1, split_idx=1),2)\n", "test_eq_type(f(1, split_idx=0), 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `as_item=True` the transform takes tuples as a whole and is applied to them." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): \n", " def encodes(self, xy): x,y=xy; return (x+y,y)\n", " def decodes(self, xy): x,y=xy; return (x-y,y)\n", "\n", "f = A(as_item=True)\n", "t = f((1,2))\n", "test_eq(t, (3,2))\n", "test_eq(f.decode(t), (1,2))\n", "f.split_idx = 1\n", "test_eq(f((1,2), split_idx=1), (3,2))\n", "test_eq(f((1,2), split_idx=0), (1,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AL(Transform): pass\n", "@AL\n", "def encodes(self, x): return L(x_+1 for x_ in x)\n", "@AL\n", "def decodes(self, x): return L(x_-1 for x_ in x)\n", "\n", "f = AL(as_item=True)\n", "t = f([1,2])\n", "test_eq(t, [2,3])\n", "test_eq(f.decode(t), [1,2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `as_item=False` the transform is applied to each element of a listy input." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def neg_int(x:numbers.Integral): return -x\n", "\n", "f = Transform(neg_int, as_item=False)\n", "test_eq(f([1]), (-1,))\n", "test_eq(f([1.]), (1.,))\n", "test_eq(f([1.,2,3.]), (1.,-2,3.))\n", "test_eq(f.decode([1,2]), (1,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class InplaceTransform(Transform):\n", " \"A `Transform` that modifies in-place and just returns whatever it's passed\"\n", " def _call(self, fn, x, split_idx=None, **kwargs):\n", " super()._call(fn,x,split_idx,**kwargs)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(InplaceTransform): pass\n", "@A\n", "def encodes(self, x:pd.Series): x.fillna(10, inplace=True)\n", "f = A()\n", "test_eq_type(f(pd.Series([1,2,None])),pd.Series([1,2,10]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TupleTransform" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TupleTransform(Transform):\n", " \"`Transform` that always treats `as_item` as `False`\"\n", " as_item_force=False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ItemTransform (Transform):\n", " \"`Transform` that always treats `as_item` as `True`\"\n", " as_item_force=True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def float_to_int(x:(float,int)): return Int(x)\n", "\n", "f = TupleTransform(float_to_int)\n", "test_eq_type(f([1.]), (Int(1),))\n", "test_eq_type(f([1]), (Int(1),))\n", "test_eq_type(f(['1']), ('1',))\n", "test_eq_type(f([1,'1']), (Int(1),'1'))\n", "test_eq(f.decode([1]), [1])\n", "\n", "test_eq_type(f(Tuple(1.)), Tuple(Int(1)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "class B(TupleTransform): pass\n", "class C(TupleTransform): pass\n", "f = B()\n", "test_eq(f([1]), [1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@B\n", "def encodes(self, x:int): return x+1\n", "@B\n", "def encodes(self, x:str): return x+'1'\n", "@B\n", "def encodes(self, x)->None: return str(x)+'!'\n", "\n", "b,c = B(),C()\n", "test_eq(b([1]), [2])\n", "test_eq(b(['1']), ('11',))\n", "test_eq(b([1.0]), ('1.0!',))\n", "test_eq(c([1]), [1])\n", "test_eq(b([1,2]), (2,3))\n", "test_eq(b.decode([2]), [2])\n", "assert pickle.loads(pickle.dumps(b))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@B\n", "def decodes(self, x:int): return x-1\n", "test_eq(b.decode([2]), [1])\n", "test_eq(b.decode(('2',)), ('2',))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Non-type-constrained functions are applied to all elements of a tuple." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(TupleTransform): pass\n", "@A\n", "def encodes(self, x): return x+1\n", "@A\n", "def decodes(self, x): return x-1\n", "\n", "f = A()\n", "t = f((1,2.0))\n", "test_eq_type(t, (2,3.0))\n", "test_eq_type(f.decode(t), (1,2.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Type-constrained functions are applied to only matching elements of a tuple, and return annotations are only applied where matching." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class B(TupleTransform):\n", " def encodes(self, x:int): return Int(x+1)\n", " def encodes(self, x:str): return x+'1'\n", " def decodes(self, x:Int): return x//2\n", "\n", "f = B()\n", "start = (1.,2,'3')\n", "t = f(start)\n", "test_eq_type(t, (1.,Int(3),'31'))\n", "test_eq(f.decode(t), (1.,Int(1),'31'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same behavior also works with abstract base classes from the `numbers` module, such as `numbers.Integral`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): pass\n", "@A\n", "def encodes(self, x:numbers.Integral): return x+1\n", "@A\n", "def encodes(self, x:float): return x*3\n", "@A\n", "def decodes(self, x:int): return x-1\n", "\n", "f = A()\n", "start = 1.0\n", "t = f(start)\n", "test_eq(t, 3.)\n", "test_eq(f.decode(t), 3)\n", "\n", "f = A(as_item=False)\n", "start = (1.,2,3.)\n", "t = f(start)\n", "test_eq(t, (3.,3,9.))\n", "test_eq(f.decode(t), (3.,2,9.))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms also accept lists (here `L` objects) as elements of a tuple." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def a(x): return L(x_+1 for x_ in x)\n", "def b(x): return L(x_-1 for x_ in x)\n", "f = TupleTransform(a,b)\n", "\n", "t = f((L(1,2),))\n", "test_eq(t, (L(2,3),))\n", "test_eq(f.decode(t), (L(1,2),))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Func -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def get_func(t, name, *args, **kwargs):\n", " \"Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined\"\n", " f = getattr(t, name, noop)\n", " return f if not (args or kwargs) else partial(f, *args, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This works for any kind of `t` supporting `getattr`, so a class 
or a module." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(get_func(operator, 'neg', 2)(), -2)\n", "test_eq(get_func(operator.neg, '__call__')(2), -2)\n", "test_eq(get_func(list, 'foobar')([2]), [2])\n", "t = get_func(torch, 'zeros', dtype=torch.int64)(5)\n", "test_eq(t.dtype, torch.int64)\n", "a = [2,1]\n", "get_func(list, 'sort')(a)\n", "test_eq(a, [1,2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms are built with multiple-dispatch: a given function can have several methods depending on the type of the object received. This is done directly with the `TypeDispatch` module and type-annotation in `Transform`, but you can also use the following class." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Func():\n", " \"Basic wrapper around a `name` with `args` and `kwargs` to call on a given type\"\n", " def __init__(self, name, *args, **kwargs): self.name,self.args,self.kwargs = name,args,kwargs\n", " def __repr__(self): return f'sig: {self.name}({self.args}, {self.kwargs})'\n", " def _get(self, t): return get_func(t, self.name, *self.args, **self.kwargs)\n", " def __call__(self,t): return mapped(self._get, t)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can call the `Func` object on any module name or type, even a list of types. It will return the corresponding function (with a default to `noop` if nothing is found) or list of functions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(Func('sqrt')(math), math.sqrt)\n", "test_eq(Func('sqrt')(torch), torch.sqrt)\n", "\n", "@patch\n", "def powx(x:math, a): return math.pow(x,a)\n", "@patch\n", "def powx(x:torch, a): return torch.pow(x,a)\n", "tst = Func('powx',a=2)([math, torch])\n", "test_eq([f.func for f in tst], [math.powx, torch.powx])\n", "for t in tst: test_eq(t.keywords, {'a': 2})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class _Sig():\n", " def __getattr__(self,k):\n", " def _inner(*args, **kwargs): return Func(k, *args, **kwargs)\n", " return _inner\n", "\n", "Sig = _Sig()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Sig[source]

\n", "\n", "> Sig(**\\*`args`**, **\\*\\*`kwargs`**)\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Sig, name=\"Sig\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Sig` is syntactic sugar to create a `Func` object more easily, with the syntax `Sig.name(*args, **kwargs)`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f = Sig.sqrt()\n", "test_eq(f(math), math.sqrt)\n", "test_eq(f(torch), torch.sqrt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pipeline -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):\n", "    \"Apply each of `tfms` to `x` (their `decode` if not `is_enc`), passing `kwargs`, maybe in `reverse` order\"\n", "    # Materialize first so `reverse` also works with one-shot iterables such as `map` or generators\n", "    if reverse: tfms = reversed(list(tfms))\n", "    for f in tfms:\n", "        if not is_enc: f = f.decode\n", "        x = f(x, **kwargs)\n", "    return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def to_int (x): return Int(x)\n", "def to_float(x): return Float(x)\n", "def double (x): return x*2\n", "def half(x)->None: return x/2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_compose(a, b, *fs): test_eq_type(compose_tfms(a, tfms=map(Transform,fs)), b)\n", "\n", "test_compose(1, Int(1), to_int)\n", "test_compose(1, Float(1), to_int,to_float)\n", "test_compose(1, Float(2), to_int,to_float,double)\n", "test_compose(2.0, 2.0, to_int,double,half)\n", "# `reverse` also works on a one-shot iterable: `to_int` runs first, then `half`\n", "test_eq_type(compose_tfms(3, tfms=map(Transform,[half,to_int]), reverse=True), 1.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform):\n", "    def encodes(self, x:float): return Float(x+1)\n", "    def decodes(self, x): return x-1\n", " \n", "tfms = [A(), Transform(math.sqrt)]\n", "t = compose_tfms(3., tfms=tfms)\n", "test_eq_type(t, Float(2.))\n", "test_eq(compose_tfms(t, tfms=tfms, is_enc=False), 
1.)\n", "test_eq(compose_tfms(4., tfms=tfms, reverse=True), 3.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [A(as_item=False), Transform(math.sqrt, as_item=False)]\n", "test_eq(compose_tfms((9,3.), tfms=tfms), (3,2.))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def mk_transform(f, as_item=True):\n", " \"Convert function `f` to `Transform` if it isn't already one\"\n", " f = instantiate(f)\n", " return f if isinstance(f,Transform) else Transform(f, as_item=as_item)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def neg(x): return -x\n", "test_eq(type(mk_transform(neg)), Transform)\n", "test_eq(type(mk_transform(math.sqrt)), Transform)\n", "test_eq(type(mk_transform(lambda a:a*2)), Transform)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def gather_attrs(o, k, nm):\n", " \"Used in __getattr__ to collect all attrs `k` from `self.{nm}`\"\n", " if k.startswith('_') or k==nm: raise AttributeError(k)\n", " att = getattr(o,nm)\n", " res = [t for t in att.attrgot(k) if t is not None]\n", " if not res: raise AttributeError(k)\n", " return res[0] if len(res)==1 else L(res)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def gather_attr_names(o, nm):\n", " \"Used in __dir__ to collect all attrs `k` from `self.{nm}`\"\n", " return L(getattr(o,nm)).map(dir).concat().unique()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Pipeline:\n", " \"A pipeline of composed (for encode/decode) transforms, setup with types\"\n", " def __init__(self, funcs=None, as_item=False, split_idx=None):\n", " self.split_idx,self.default = split_idx,None\n", " if isinstance(funcs, Pipeline): self.fs = funcs.fs\n", " else:\n", " if 
isinstance(funcs, Transform): funcs = [funcs]\n", " self.fs = L(ifnone(funcs,[noop])).map(mk_transform).sorted(key='order')\n", " for f in self.fs:\n", " name = camel2snake(type(f).__name__)\n", " a = getattr(self,name,None)\n", " if a is not None: f = L(a)+f\n", " setattr(self, name, f)\n", " self.set_as_item(as_item)\n", "\n", " def set_as_item(self, as_item):\n", " self.as_item = as_item\n", " for f in self.fs: f.as_item = as_item\n", "\n", " def setup(self, items=None):\n", " tfms = self.fs[:]\n", " self.fs.clear()\n", " for t in tfms: self.add(t,items)\n", "\n", " def add(self,t, items=None):\n", " t.setup(items)\n", " self.fs.append(t)\n", "\n", " def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)\n", " def __repr__(self): return f\"Pipeline: {self.fs}\"\n", " def __getitem__(self,i): return self.fs[i]\n", " def __setstate__(self,data): self.__dict__.update(data)\n", " def __getattr__(self,k): return gather_attrs(self, k, 'fs')\n", " def __dir__(self): return super().__dir__() + gather_attr_names(self, 'fs')\n", "\n", " def decode (self, o, full=True):\n", " if full: return compose_tfms(o, tfms=self.fs, is_enc=False, reverse=True, split_idx=self.split_idx)\n", " #Not full means we decode up to the point the item knows how to show itself.\n", " for f in reversed(self.fs):\n", " if self._is_showable(o): return o\n", " o = f.decode(o, split_idx=self.split_idx)\n", " return o\n", "\n", " def show(self, o, ctx=None, **kwargs):\n", " o = self.decode(o, full=False)\n", " o1 = [o] if self.as_item or not is_listy(o) else o\n", " for o_ in o1:\n", " if hasattr(o_, 'show'): ctx = o_.show(ctx=ctx, **kwargs)\n", " return ctx\n", "\n", " def _is_showable(self, o):\n", " return all(hasattr(o_, 'show') for o_ in o) if is_listy(o) else hasattr(o, 'show')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "add_docs(Pipeline,\n", " __call__=\"Compose `__call__` of all `fs` on `o`\",\n", " 
 decode=\"Compose `decode` of all `fs` on `o`\",\n", " show=\"Show `o`, a single item from a tuple, decoding as needed\",\n", " add=\"Add transform `t`\",\n", " set_as_item=\"Set value of `as_item` for all transforms\",\n", " setup=\"Call each tfm's `setup` in order\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Pipeline` is a wrapper for `compose_tfms`. You can pass instances of `Transform` or regular functions in `funcs`; the `Pipeline` will wrap them all in `Transform` (and instantiate them if needed) during initialization. It handles the transforms' `setup` by adding them one at a time and calling `setup` on each, goes through them in order in `__call__` (and in reverse order in `decode`), and can `show` an object by decoding through the transforms until it reaches an object that knows how to show itself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Empty pipeline is noop\n", "pipe = Pipeline()\n", "test_eq(pipe(1), 1)\n", "pipe.set_as_item(False)\n", "test_eq(pipe((1,)), (1,))\n", "# Check pickle works\n", "assert pickle.loads(pickle.dumps(pipe))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class IntFloatTfm(Transform):\n", " def encodes(self, x): return Int(x)\n", " def decodes(self, x): return Float(x)\n", " foo=1\n", "\n", "int_tfm=IntFloatTfm()\n", "\n", "def neg(x): return -x\n", "neg_tfm = Transform(neg, neg)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pipe = Pipeline([neg_tfm, int_tfm])\n", "\n", "start = 2.0\n", "t = pipe(start)\n", "test_eq_type(t, Int(-2))\n", "test_eq_type(pipe.decode(t), Float(start))\n", "test_stdout(lambda:pipe.show(t), '-2')\n", "\n", "pipe.set_as_item(False)\n", "test_stdout(lambda:pipe.show(pipe((1.,2.))), '-1\\n-2')\n", "test_eq(pipe.foo, 1)\n", "assert 'foo' in dir(pipe)\n", "assert 'int_float_tfm' in dir(pipe)" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "Transforms are available as attributes named with the snake_case version of the names of their types. Attributes in transforms can be directly accessed as attributes of the pipeline." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(pipe.int_float_tfm, int_tfm)\n", "test_eq(pipe.foo, 1)\n", "\n", "pipe = Pipeline([int_tfm, int_tfm])\n", "test_eq(pipe.int_float_tfm[0], int_tfm)\n", "test_eq(pipe.foo, [1,1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check opposite order\n", "pipe = Pipeline([int_tfm,neg_tfm])\n", "t = pipe(start)\n", "test_eq(t, -2)\n", "test_stdout(lambda:pipe.show(t), '-2')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform):\n", " def encodes(self, x): return int(x)\n", " def decodes(self, x): return Float(x)\n", "\n", "pipe = Pipeline([neg_tfm, A])\n", "t = pipe(start)\n", "test_eq_type(t, -2)\n", "test_eq_type(pipe.decode(t), Float(start))\n", "test_stdout(lambda:pipe.show(t), '-2.0')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s2 = (1,2)\n", "pipe.set_as_item(False)\n", "t = pipe(s2)\n", "test_eq_type(t, (-1,-2))\n", "test_eq_type(pipe.decode(t), (Float(1.),Float(2.)))\n", "test_stdout(lambda:pipe.show(t), '-1.0\\n-2.0')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class B(Transform):\n", " def encodes(self, x): return x+1\n", " def decodes(self, x): return x-1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# f2 opens a path and resizes the image to 128x128, f3 turns the PIL image into an ArrayImage, f1 negates it\n", "def f1(x:ArrayImage): return -x\n", "def f2(x): return Image.open(x).resize((128,128))\n", "def f3(x:Image.Image): return ArrayImage(array(x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"pipe = Pipeline([f2,f3,f1])\n", "t = pipe(TEST_IMAGE)\n", "test_eq(type(t), ArrayImage)\n", "test_eq(t, -array(f3(f2(TEST_IMAGE))))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "pipe = Pipeline([f2,f3])\n", "t = pipe(TEST_IMAGE)\n", "ax = pipe.show(t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_fig_exists(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Check filtering is properly applied\n", "add1 = B()\n", "add1.split_idx = 1\n", "pipe = Pipeline([neg_tfm, A(), add1])\n", "test_eq(pipe(start), -2)\n", "pipe.split_idx=1\n", "test_eq(pipe(start), -1)\n", "pipe.split_idx=0\n", "test_eq(pipe(start), -2)\n", "for t in [None, 0, 1]:\n", " pipe.split_idx=t\n", " test_eq(pipe.decode(pipe(start)), start)\n", " test_stdout(lambda: pipe.show(pipe(start)), \"-2.0\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#TODO: method examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Pipeline.__call__[source]

\n", "\n", "> Pipeline.__call__(**`o`**)\n", "\n", "Compose `__call__` of all `fs` on `o`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Pipeline.__call__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Pipeline.decode[source]

\n", "\n", "> Pipeline.decode(**`o`**, **`full`**=*`True`*)\n", "\n", "Compose `decode` of all `fs` on `o`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Pipeline.decode)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Pipeline.setup[source]

\n", "\n", "> Pipeline.setup(**`items`**=*`None`*)\n", "\n", "Call each tfm's `setup` in order" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Pipeline.setup)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "During setup, the `Pipeline` starts with no transforms and adds them back one at a time, calling each transform's `setup` as it is added. This way, each transform's `setup` sees the items processed only by the transforms that come before it, not by those that come after." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test is with TfmdList" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_core_foundation.ipynb.\n", "Converted 01a_core_utils.ipynb.\n", "Converted 01b_core_dispatch.ipynb.\n", "Converted 01c_core_transform.ipynb.\n", "Converted 02_core_script.ipynb.\n", "Converted 03_torchcore.ipynb.\n", "Converted 03a_layers.ipynb.\n", "Converted 04_data_load.ipynb.\n", "Converted 05_data_core.ipynb.\n", "Converted 06_data_transforms.ipynb.\n", "Converted 07_data_block.ipynb.\n", "Converted 08_vision_core.ipynb.\n", "Converted 09_vision_augment.ipynb.\n", "Converted 09a_vision_data.ipynb.\n", "Converted 10_pets_tutorial.ipynb.\n", "Converted 11_vision_models_xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_learner.ipynb.\n", "Converted 13a_metrics.ipynb.\n", "Converted 14_callback_schedule.ipynb.\n", "Converted 14a_callback_data.ipynb.\n", "Converted 15_callback_hook.ipynb.\n", "Converted 15a_vision_models_unet.ipynb.\n", "Converted 16_callback_progress.ipynb.\n", "Converted 17_callback_tracker.ipynb.\n", "Converted 18_callback_fp16.ipynb.\n", "Converted 19_callback_mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision_learner.ipynb.\n",
"Converted 22_tutorial_imagenette.ipynb.\n", "Converted 23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", "Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 70_callback_wandb.ipynb.\n", "Converted 71_callback_tensorboard.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n" ] } ], "source": [ "#hide\n", "from local.notebook.export import notebook2script\n", "notebook2script(all_fs=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }