{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp transform" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastcore.imports import *\n", "from fastcore.foundation import *\n", "from fastcore.utils import *\n", "from fastcore.dispatch import *\n", "import inspect" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *\n", "from fastcore.test import *\n", "from fastcore.nb_imports import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Transforms\n", "\n", "> Definition of `Transform` and `Pipeline`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The classes here provide functionality for creating a composition of *partially reversible functions*. By \"partially reversible\" we mean that a transform can be `decode`d, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).\n", "\n", "Classes are also provided for composing transforms and mapping them over collections. `Pipeline` is a transform that composes several `Transform`s and knows how to decode them or show an encoded item."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transform -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_tfm_methods = 'encodes','decodes','setups'\n", "\n", "class _TfmDict(dict):\n", " def __setitem__(self,k,v):\n", " if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n", " if k not in self: super().__setitem__(k,TypeDispatch())\n", " self[k].add(v)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class _TfmMeta(type):\n", " def __new__(cls, name, bases, dict):\n", " res = super().__new__(cls, name, bases, dict)\n", " for nm in _tfm_methods:\n", " base_td = [getattr(b,nm,None) for b in bases]\n", " if nm in res.__dict__: getattr(res,nm).bases = base_td\n", " else: setattr(res, nm, TypeDispatch(bases=base_td))\n", " res.__signature__ = inspect.signature(res.__init__)\n", " return res\n", "\n", " def __call__(cls, *args, **kwargs):\n", " f = args[0] if args else None\n", " n = getattr(f,'__name__',None)\n", " if callable(f) and n in _tfm_methods:\n", " getattr(cls,n).add(f)\n", " return f\n", " return super().__call__(*args, **kwargs)\n", "\n", " @classmethod\n", " def __prepare__(cls, name, bases): return _TfmDict()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _get_name(o):\n", " if hasattr(o,'__qualname__'): return o.__qualname__\n", " if hasattr(o,'__name__'): return o.__name__\n", " return o.__class__.__name__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _is_tuple(o): return isinstance(o, tuple) and not hasattr(o, '_fields')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Transform(metaclass=_TfmMeta):\n", " \"Delegates (`__call__`,`decode`,`setup`) to (encodes,decodes,setups) if 
`split_idx` matches\"\n", " split_idx,init_enc,order,train_setup = None,None,0,None\n", " def __init__(self, enc=None, dec=None, split_idx=None, order=None):\n", " self.split_idx = ifnone(split_idx, self.split_idx)\n", " if order is not None: self.order=order\n", " self.init_enc = enc or dec\n", " if not self.init_enc: return\n", "\n", " self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()\n", " if enc:\n", " self.encodes.add(enc)\n", " self.order = getattr(enc,'order',self.order)\n", " if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())\n", " self._name = _get_name(enc)\n", " if dec: self.decodes.add(dec)\n", "\n", " @property\n", " def name(self): return getattr(self, '_name', _get_name(self))\n", " def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)\n", " def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs)\n", " def __repr__(self): return f'{self.name}:\\nencodes: {self.encodes}decodes: {self.decodes}'\n", "\n", " def setup(self, items=None, train_setup=False):\n", " train_setup = train_setup if self.train_setup is None else self.train_setup\n", " return self.setups(getattr(items, 'train', items) if train_setup else items)\n", "\n", " def _call(self, fn, x, split_idx=None, **kwargs):\n", " if split_idx!=self.split_idx and self.split_idx is not None: return x\n", " return self._do_call(getattr(self, fn), x, **kwargs)\n", "\n", " def _do_call(self, f, x, **kwargs):\n", " if not _is_tuple(x):\n", " if f is None: return x\n", " ret = f.returns_none(x) if hasattr(f,'returns_none') else None\n", " return retain_type(f(x, **kwargs), x, ret)\n", " res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)\n", " return retain_type(res, x)\n", "\n", "add_docs(Transform, decode=\"Delegate to decodes to undo transform\", setup=\"Delegate to setups to set up transform\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { 
"text/markdown": [ "class Transform[source]\n",
 "\n", "> Transform(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`order`**=*`None`*)\n", "\n", "Delegates (`__call__`,`decode`,`setup`) to (encodes,decodes,setups) if `split_idx` matches" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `Transform` is the main building block of the fastai data pipelines. In the most general terms a transform can be any function you want to apply to your data; however, the `Transform` class provides several mechanisms that make building transforms easy and flexible." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The main `Transform` features:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Type dispatch** - Type annotations are used to determine whether a transform should be applied to the given argument. You can also provide several implementations, and the one to run is chosen based on the type. This is useful, for example, when running both independent and dependent variables through the pipeline, where some transforms only make sense for one and not the other. Another use case is designing a transform that handles different data formats. Note that if a transform takes multiple arguments, only the type of the first one is used for dispatch.\n", "- **Handling of tuples** - When a tuple (or a subclass of tuple) of data is passed to a transform, the transform is applied to each element separately. You can opt out of this behavior by passing a list or an `L`, as only tuples get this specific behavior. An alternative is to use `ItemTransform`, defined below, which always takes the input as a whole.\n", "- **Reversibility** - A transform can be made reversible by implementing the `decodes` method. 
This is mainly used to turn something like a category that is encoded as a number back into a human-readable label for display. Like the regular call method, `decode` is applied to each element of a tuple separately.\n", "- **Type propagation** - Whenever possible, a transform tries to return data of the same type it received. This is mainly used to preserve the semantics of types like `ArrayImage`, which is defined below as a thin subclass of NumPy's `ndarray`. You can opt out of this behavior by adding a `->None` return type annotation.\n", "- **Preprocessing** - The `setup` method can be used to perform any one-time calculations to be used later by the transform, for example generating a vocabulary to encode categorical data.\n", "- **Filtering based on the dataset type** - By setting the `split_idx` attribute, you can make the transform apply only to a specific subset of a `DataSource`, such as the training set but not the validation set.\n", "- **Ordering** - You can set the `order` attribute, which `Pipeline` uses to sort its transforms, for example when it needs to merge two lists of transforms.\n", "- **Appending new behavior with decorators** - You can easily extend an existing `Transform` by creating `encodes` or `decodes` methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you want them patched into. This lets fastai users add their own behavior, and lets multiple modules contribute to the same transform."
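] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The **Preprocessing** and **Reversibility** bullets above can be illustrated together. The cell below is a minimal sketch, not part of this module: `VocabTfm` and `_Items` are hypothetical names invented for the example, with `_Items` standing in for a collection that exposes a `train` subset. `setups` builds a sorted vocabulary, `encodes` maps a label to its index, and `decodes` maps the index back. Passing `train_setup=True` makes `setup` use `items.train` (when that attribute exists) instead of the full `items`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical list subclass exposing a `train` subset (illustration only)\n", "class _Items(list): train = ['a','b']\n", "\n", "# Hypothetical transform: builds a vocab in `setups`, maps label <-> index\n", "class VocabTfm(Transform):\n", "    def setups(self, items): self.vocab = sorted(set(items))\n", "    def encodes(self, x:str): return self.vocab.index(x)\n", "    def decodes(self, x:int): return self.vocab[x]\n", "\n", "v = VocabTfm()\n", "v.setup(_Items(['c','a','b','a']))  # train_setup=False: uses all items\n", "test_eq(v.vocab, ['a','b','c'])\n", "test_eq(v('c'), 2)\n", "test_eq(v.decode(2), 'c')\n", "\n", "v.setup(_Items(['c','a','b']), train_setup=True)  # uses `items.train` only\n", "test_eq(v.vocab, ['a','b'])"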
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining a `Transform`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are a few ways to create a transform, each with a different balance of simplicity and flexibility:\n", "- **Extending the `Transform` class** - Use inheritance to implement the methods you want.\n", "- **Passing methods to the constructor** - Instantiate the `Transform` class and pass your functions as the `enc` and `dec` arguments.\n", "- **@Transform decorator** - Turn any function into a `Transform` by just adding a decorator - very straightforward if all you need is a single `encodes` implementation.\n", "- **Passing a function to fastai APIs** - Same as above, but when passing a function to other transform-aware classes like `Pipeline` or `TfmdDS` you don't even need a decorator. Your function will get converted to a `Transform` automatically." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): pass\n", "@A\n", "def encodes(self, x): return x+1\n", "f1 = A()\n", "test_eq(f1(1), 2)\n", "\n", "class B(A): pass\n", "@B\n", "def decodes(self, x): return x-1\n", "f2 = B()\n", "test_eq(f2(1), 2)\n", "test_eq(f2.decode(2), 1)\n", "test_eq(f1.decode(2), 2)\n", "\n", "class A(Transform): pass\n", "f3 = A()\n", "test_eq_type(f3(2), 2)\n", "test_eq_type(f3.decode(2.0), 2.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Transform` can be called on a function, or used as a decorator, to turn that function into a `Transform`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f = Transform(lambda o:o//2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq_type(f(2), 1)\n", "test_eq_type(f.decode(2.0), 2.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@Transform\n", "def f(x): return x//2\n", "test_eq_type(f(2), 1)\n", "test_eq_type(f.decode(2.0), 2.0)\n", "\n", "@Transform\n", "def f(x): return x*2\n", "test_eq_type(f(2), 4)\n", "test_eq_type(f.decode(2.0), 2.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can derive from `Transform` and define `encodes` as your encoding function (and, optionally, `decodes` and `setups`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ArrayImage(ndarray):\n", " _show_args = {'cmap':'viridis'}\n", " def __new__(cls, x, *args, **kwargs):\n", " if isinstance(x,tuple): return super().__new__(cls, x, *args, **kwargs)\n", " if args or kwargs: raise RuntimeError('Unknown array init args')\n", " if not isinstance(x,ndarray): x = array(x)\n", " return x.view(cls)\n", " \n", " def show(self, ctx=None, figsize=None, **kwargs):\n", " if ctx is None: _,ctx = plt.subplots(figsize=figsize)\n", " ctx.imshow(self, **{**self._show_args, **kwargs})\n", " ctx.axis('off')\n", " return ctx\n", " \n", "im = Image.open(TEST_IMAGE)\n", "im_t = ArrayImage(im)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "A:\n", "encodes: (ArrayImage,object) -> encodes\n", "decodes: (ArrayImage,object) -> decodes" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class A(Transform):\n", " def encodes(self, x:ArrayImage): return -x\n", " def decodes(self, x:ArrayImage): return x+1\n", " def setups (self, x:ArrayImage): x.foo = 'a'\n", "f = A()\n", "t = f(im_t)\n", "test_eq(t, -im_t)\n", "test_eq(f(1), 1)\n",
"test_eq(type(t), ArrayImage)\n", "test_eq(f.decode(t), -im_t+1)\n", "test_eq(f.decode(1), 1)\n", "f.setup(im_t)\n", "test_eq(im_t.foo, 'a')\n", "t2 = array(1)\n", "f.setup(t2)\n", "assert not hasattr(t2,'foo')\n", "f" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without a return annotation, we get an `Int` back since that's what was passed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): pass\n", "@A\n", "def encodes(self, x:Int): return x//2\n", "@A\n", "def encodes(self, x:float): return x+1\n", "\n", "f = A()\n", "test_eq_type(f(Int(2)), Int(1))\n", "test_eq_type(f(2), 2)\n", "test_eq_type(f(2.), 3.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without a return annotation, the result is only cast back to the input type when that type is a subclass of the result's type. If the annotation is a tuple, then any type in the tuple will match." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform):\n", " def encodes(self, x:(Int,float)): return x/2\n", " def encodes(self, x:(str,list)): return str(x)+'1'\n", "\n", "f = A()\n", "test_eq_type(f(Int(2)), 1.)\n", "test_eq_type(f(2), 2)\n", "test_eq_type(f(Float(2.)), Float(1.))\n", "test_eq_type(f('a'), 'a1')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With a `None` return annotation, we get back whatever Python normally produces." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def func(x)->None: return x/2\n", "f = Transform(func)\n", "test_eq_type(f(2), 1.)\n", "test_eq_type(f(2.), 1.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since `decodes` has no return annotation, and `encodes` created an `Int` that we pass to `decode`, we end up with an `Int`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def func(x): return Int(x+1)\n", "def dec (x): return x-1\n", "f = Transform(func,dec)\n", "t = f(1)\n", "test_eq_type(t, Int(2))\n", "test_eq_type(f.decode(t), Int(1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the transform has a `split_idx`, it is only applied when the `split_idx` argument passed to the call matches." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f.split_idx = 1\n", "test_eq(f(1, split_idx=1),2)\n", "test_eq_type(f(1, split_idx=0), 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `Transform` receives a list as a whole (unlike a tuple, it is not split into elements) and is applied to it directly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): \n", " def encodes(self, xy): x,y=xy; return [x+y,y]\n", " def decodes(self, xy): x,y=xy; return [x-y,y]\n", "\n", "f = A()\n", "t = f([1,2])\n", "test_eq(t, [3,2])\n", "test_eq(f.decode(t), [1,2])\n", "f.split_idx = 1\n", "test_eq(f([1,2], split_idx=1), [3,2])\n", "test_eq(f([1,2], split_idx=0), [1,2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AL(Transform): pass\n", "@AL\n", "def encodes(self, x): return L(x_+1 for x_ in x)\n", "@AL\n", "def decodes(self, x): return L(x_-1 for x_ in x)\n", "\n", "f = AL()\n", "t = f([1,2])\n", "test_eq(t, [2,3])\n", "test_eq(f.decode(t), [1,2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms are applied to each element of a tuple."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def neg_int(x:numbers.Integral): return -x\n", "\n", "f = Transform(neg_int)\n", "test_eq(f((1,)), (-1,))\n", "test_eq(f((1.,)), (1.,))\n", "test_eq(f((1.,2,3.)), (1.,-2,3.))\n", "test_eq(f.decode((1,2)), (1,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(f.input_types, numbers.Integral)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class InplaceTransform(Transform):\n", " \"A `Transform` that modifies in-place and just returns whatever it's passed\"\n", " def _call(self, fn, x, split_idx=None, **kwargs):\n", " super()._call(fn,x,split_idx,**kwargs)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(InplaceTransform): pass\n", "@A\n", "def encodes(self, x:pd.Series): x.fillna(10, inplace=True)\n", "f = A()\n", "test_eq_type(f(pd.Series([1,2,None])),pd.Series([1,2,10]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class B(Transform): pass\n", "\n", "@B\n", "def encodes(self, x:int): return x+1\n", "@B\n", "def encodes(self, x:str): return x+'1'\n", "@B\n", "def encodes(self, x)->None: return str(x)+'!'\n", "\n", "b = B()\n", "test_eq(b([1]), '[1]!')\n", "test_eq(b((1,)), (2,))\n", "test_eq(b(('1',)), ('11',))\n", "test_eq(b([1.0]), '[1.0]!')\n", "test_eq(b.decode([2]), [2])\n", "assert pickle.loads(pickle.dumps(b))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@B\n", "def decodes(self, x:int): return x-1\n", "test_eq(b.decode((2,)), (1,))\n", "test_eq(b.decode(('2',)), ('2',))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"Non-type-constrained functions are applied to all elements of a tuple." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): pass\n", "@A\n", "def encodes(self, x): return x+1\n", "@A\n", "def decodes(self, x): return x-1\n", "\n", "f = A()\n", "t = f((1,2.0))\n", "test_eq_type(t, (2,3.0))\n", "test_eq_type(f.decode(t), (1,2.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Type-constrained functions are applied only to matching elements of a tuple, and return annotations are only applied where matching." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class B(Transform):\n", " def encodes(self, x:int): return Int(x+1)\n", " def encodes(self, x:str): return x+'1'\n", " def decodes(self, x:Int): return x//2\n", "\n", "f = B()\n", "start = (1.,2,'3')\n", "t = f(start)\n", "test_eq_type(t, (1.,Int(3),'31'))\n", "test_eq(f.decode(t), (1.,Int(1),'31'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dispatching over tuples works recursively, by the way:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f = B()\n", "start = (1.,(2,'3'))\n", "t = f(start)\n", "test_eq_type(t, (1.,(Int(3),'31')))\n", "test_eq(f.decode(t), (1.,(Int(1),'31')))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same behavior also works with abstract base classes, such as `numbers.Integral` from the `numbers` module."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): pass\n", "@A\n", "def encodes(self, x:numbers.Integral): return x+1\n", "@A\n", "def encodes(self, x:float): return x*3\n", "@A\n", "def decodes(self, x:int): return x-1\n", "\n", "f = A()\n", "start = 1.0\n", "t = f(start)\n", "test_eq(t, 3.)\n", "test_eq(f.decode(t), 3)\n", "\n", "start = (1.,2,3.)\n", "t = f(start)\n", "test_eq(t, (3.,3,9.))\n", "test_eq(f.decode(t), (3.,2,9.))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class DisplayedTransform(Transform):\n", " \"A transform with a `__repr__` that shows its attrs\"\n", "\n", " @property\n", " def name(self): return f\"{super().name} -- {getattr(self,'__stored_args__',{})}\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms are normally represented by just their class name and a list of `encodes` and `decodes` implementations:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "A:\n", "encodes: (object,object) -> noop\n", "decodes: (object,object) -> noop" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class A(Transform): encodes,decodes = noop,noop\n", "f = A()\n", "f" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `DisplayedTransform` will additionally show the contents of all attributes saved with `store_attr` (as recorded in `__stored_args__`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "A -- {'a': 1, 'b': 2}:\n", "encodes: (object,object) -> noop\n", "decodes: " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class A(DisplayedTransform):\n", " encodes = noop\n", " def __init__(self, a, b=2):\n", " super().__init__()\n", " store_attr()\n", " \n", "A(a=1,b=2)" ] }, {
"cell_type": "markdown", "metadata": {}, "source": [ "#### ItemTransform -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ItemTransform(Transform):\n", " \"A transform that always takes tuples as items\"\n", " _retain = True\n", " def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)\n", " def decode(self, x, **kwargs): return self._call1(x, 'decode', **kwargs)\n", " def _call1(self, x, name, **kwargs):\n", " if not _is_tuple(x): return getattr(super(), name)(x, **kwargs)\n", " y = getattr(super(), name)(list(x), **kwargs)\n", " if not self._retain: return y\n", " if is_listy(y) and not isinstance(y, tuple): y = tuple(y)\n", " return retain_type(y, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`ItemTransform` is the class to use to opt out of the default tuple-splitting behavior of `Transform`: the whole tuple is passed to `encodes` and `decodes`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AIT(ItemTransform): \n", " def encodes(self, xy): x,y=xy; return (x+y,y)\n", " def decodes(self, xy): x,y=xy; return (x-y,y)\n", " \n", "f = AIT()\n", "test_eq(f((1,2)), (3,2))\n", "test_eq(f.decode((3,2)), (1,2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you pass a special tuple subclass, the usual type-retaining behavior of `Transform` will keep it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _T(tuple): pass\n", "x = _T((1,2))\n", "test_eq_type(f(x), _T((3,2)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "f.split_idx = 0\n", "test_eq_type(f((1,2)), (1,2))\n", "test_eq_type(f((1,2), split_idx=0), (3,2))\n", "test_eq_type(f.decode((1,2)), (1,2))\n", "test_eq_type(f.decode((3,2), split_idx=0), (1,2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class Get(ItemTransform):\n",
" _retain = False\n", " def encodes(self, x): return x[0]\n", " \n", "g = Get()\n", "test_eq(g([1,2,3]), 1)\n", "test_eq(g(L(1,2,3)), 1)\n", "test_eq(g(np.array([1,2,3])), 1)\n", "test_eq_type(g((['a'], ['b', 'c'])), ['a'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class A(ItemTransform): \n", " def encodes(self, x): return _T((x,x))\n", " def decodes(self, x): return _T(x)\n", " \n", "f = A()\n", "test_eq(type(f.decode((1,1))), _T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Func -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def get_func(t, name, *args, **kwargs):\n", " \"Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined\"\n", " f = getattr(t, name, noop)\n", " return f if not (args or kwargs) else partial(f, *args, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This works for any kind of `t` supporting `getattr`, such as a class or a module." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(get_func(operator, 'neg', 2)(), -2)\n", "test_eq(get_func(operator.neg, '__call__')(2), -2)\n", "test_eq(get_func(list, 'foobar')([2]), [2])\n", "a = [2,1]\n", "get_func(list, 'sort')(a)\n", "test_eq(a, [1,2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms are built with multiple dispatch: a given function can have several methods depending on the type of the object received. This is done directly with the `TypeDispatch` class and type annotations in `Transform`, but you can also use the following class."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Func():\n", " \"Basic wrapper around a `name` with `args` and `kwargs` to call on a given type\"\n", " def __init__(self, name, *args, **kwargs): self.name,self.args,self.kwargs = name,args,kwargs\n", " def __repr__(self): return f'sig: {self.name}({self.args}, {self.kwargs})'\n", " def _get(self, t): return get_func(t, self.name, *self.args, **self.kwargs)\n", " def __call__(self,t): return mapped(self._get, t)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can call the `Func` object on any module or type, or even on a list of types. It will return the corresponding function (defaulting to `noop` if nothing is found) or a list of functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(Func('sqrt')(math), math.sqrt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class _Sig():\n", " def __getattr__(self,k):\n", " def _inner(*args, **kwargs): return Func(k, *args, **kwargs)\n", " return _inner\n", "\n", "Sig = _Sig()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "Sig[source]\n",
 "\n", "> Sig(**\\*`args`**, **\\*\\*`kwargs`**)\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Sig, name=\"Sig\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Sig` is just syntactic sugar to create a `Func` object more easily with the syntax `Sig.name(*args, **kwargs)`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pipeline -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):\n", " \"Apply all `tfms` to `x` (using their `decode` if not `is_enc`), maybe in `reverse` order\"\n", " if reverse: tfms = reversed(tfms)\n", " for f in tfms:\n", " if not is_enc: f = f.decode\n", " x = f(x, **kwargs)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def to_int (x): return Int(x)\n", "def to_float(x): return Float(x)\n", "def double (x): return x*2\n", "def half(x)->None: return x/2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_compose(a, b, *fs): test_eq_type(compose_tfms(a, tfms=map(Transform,fs)), b)\n", "\n", "test_compose(1, Int(1), to_int)\n", "test_compose(1, Float(1), to_int,to_float)\n", "test_compose(1, Float(2), to_int,to_float,double)\n", "test_compose(2.0, 2.0, to_int,double,half)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform):\n", " def encodes(self, x:float): return Float(x+1)\n", " def decodes(self, x): return x-1\n", " \n", "tfms = [A(), Transform(math.sqrt)]\n", "t = compose_tfms(3., tfms=tfms)\n", "test_eq_type(t, Float(2.))\n", "test_eq(compose_tfms(t, tfms=tfms, is_enc=False), 1.)\n", "test_eq(compose_tfms(4., 
tfms=tfms, reverse=True), 3.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [A(), Transform(math.sqrt)]\n", "test_eq(compose_tfms((9,3.), tfms=tfms), (3,2.))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def mk_transform(f):\n", "    \"Convert function `f` to `Transform` if it isn't already one\"\n", "    f = instantiate(f)\n", "    return f if isinstance(f,(Transform,Pipeline)) else Transform(f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def gather_attrs(o, k, nm):\n", "    \"Used in __getattr__ to collect all attrs `k` from `self.{nm}`\"\n", "    if k.startswith('_') or k==nm: raise AttributeError(k)\n", "    att = getattr(o,nm)\n", "    res = [t for t in att.attrgot(k) if t is not None]\n", "    if not res: raise AttributeError(k)\n", "    return res[0] if len(res)==1 else L(res)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def gather_attr_names(o, nm):\n", "    \"Used in __dir__ to collect all attr names from `self.{nm}`\"\n", "    return L(getattr(o,nm)).map(dir).concat().unique()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Pipeline:\n", "    \"A pipeline of composed (for encode/decode) transforms, setup with types\"\n", "    def __init__(self, funcs=None, split_idx=None):\n", "        self.split_idx,self.default = split_idx,None\n", "        if funcs is None: funcs = []\n", "        if isinstance(funcs, Pipeline): self.fs = funcs.fs\n", "        else:\n", "            if isinstance(funcs, Transform): funcs = [funcs]\n", "            self.fs = L(ifnone(funcs,[noop])).map(mk_transform).sorted(key='order')\n", "        for f in self.fs:\n", "            name = camel2snake(type(f).__name__)\n", "            a = getattr(self,name,None)\n", "            if a is not None: f = L(a)+f\n", "            setattr(self, name, f)\n", "\n", "    def setup(self, items=None, train_setup=False):\n", "        tfms = self.fs[:]\n", "        self.fs.clear()\n", "        for t in tfms: self.add(t,items, train_setup)\n", "\n", "    def add(self,t, items=None, train_setup=False):\n", "        t.setup(items, train_setup)\n", "        self.fs.append(t)\n", "\n", "    def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)\n", "    def __repr__(self): return f\"Pipeline: {' -> '.join([f.name for f in self.fs if f.name != 'noop'])}\"\n", "    def __getitem__(self,i): return self.fs[i]\n", "    def __setstate__(self,data): self.__dict__.update(data)\n", "    def __getattr__(self,k): return gather_attrs(self, k, 'fs')\n", "    def __dir__(self): return super().__dir__() + gather_attr_names(self, 'fs')\n", "\n", "    def decode(self, o, full=True):\n", "        if full: return compose_tfms(o, tfms=self.fs, is_enc=False, reverse=True, split_idx=self.split_idx)\n", "        #Not full means we decode up to the point the item knows how to show itself.\n", "        for f in reversed(self.fs):\n", "            if self._is_showable(o): return o\n", "            o = f.decode(o, split_idx=self.split_idx)\n", "        return o\n", "\n", "    def show(self, o, ctx=None, **kwargs):\n", "        o = self.decode(o, full=False)\n", "        o1 = (o,) if not _is_tuple(o) else o\n", "        if hasattr(o, 'show'): ctx = o.show(ctx=ctx, **kwargs)\n", "        else:\n", "            for o_ in o1:\n", "                if hasattr(o_, 'show'): ctx = o_.show(ctx=ctx, **kwargs)\n", "        return ctx\n", "\n", "    def _is_showable(self, o):\n", "        if hasattr(o, 'show'): return True\n", "        if _is_tuple(o): return all(hasattr(o_, 'show') for o_ in o)\n", "        return False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "add_docs(Pipeline,\n", "         __call__=\"Compose `__call__` of all `fs` on `o`\",\n", "         decode=\"Compose `decode` of all `fs` on `o`\",\n", "         show=\"Show `o`, a single item from a tuple, decoding as needed\",\n", "         add=\"Add transform `t`\",\n", "         setup=\"Call each tfm's `setup` in order\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Pipeline` is a wrapper for `compose_tfms`. You can pass instances of `Transform` or regular functions in `funcs`; during initialization, the `Pipeline` wraps any plain function in a `Transform` (instantiating transform classes if needed) and sorts the transforms by their `order` attribute. It handles `setup` by adding the transforms back one at a time and calling `setup` on each, applies them in order in `__call__` and in reverse order in `decode`, and can `show` an object by decoding it through the transforms until it reaches an object that knows how to show itself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Empty pipeline is noop\n", "pipe = Pipeline()\n", "test_eq(pipe(1), 1)\n", "test_eq(pipe((1,)), (1,))\n", "# Check pickle works\n", "assert pickle.loads(pickle.dumps(pipe))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class IntFloatTfm(Transform):\n", "    def encodes(self, x): return Int(x)\n", "    def decodes(self, x): return Float(x)\n", "    foo=1\n", "\n", "int_tfm=IntFloatTfm()\n", "\n", "def neg(x): return -x\n", "neg_tfm = Transform(neg, neg)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pipe = Pipeline([neg_tfm, int_tfm])\n", "\n", "start = 2.0\n", "t = pipe(start)\n", "test_eq_type(t, Int(-2))\n", "test_eq_type(pipe.decode(t), Float(start))\n", "test_stdout(lambda:pipe.show(t), '-2')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pipe = Pipeline([neg_tfm, int_tfm])\n", "t = pipe(start)\n", "test_stdout(lambda:pipe.show(pipe((1.,2.))), '-1\\n-2')\n", "test_eq(pipe.foo, 1)\n", "assert 'foo' in dir(pipe)\n", "assert 'int_float_tfm' in dir(pipe)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms are available as attributes of the pipeline, named with the snake_case version of their type names. Attributes of the transforms can also be accessed directly as attributes of the pipeline. When several transforms share a type name or define the same attribute, the result is an `L` of all the matching values."
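 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below is a minimal, standalone sketch of the naming and attribute-delegation behavior described above, under simplifying assumptions; it is *not* the actual `Pipeline` implementation. `MiniPipeline`, `_snake` and `DemoTfm` are hypothetical names used only for illustration, and `_snake` is a naive stand-in for `camel2snake` that does not handle acronyms. As in `Pipeline`, `None` values are skipped and repeated names collect all matching values (a plain list here, an `L` in `Pipeline`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical, simplified sketch for illustration only (not the real `Pipeline`)\n", "def _snake(name):\n", "    # naive CamelCase -> snake_case, e.g. 'DemoTfm' -> 'demo_tfm' (no acronym handling)\n", "    return ''.join(('_' if c.isupper() and i else '') + c.lower() for i,c in enumerate(name))\n", "\n", "class MiniPipeline:\n", "    def __init__(self, fs):\n", "        self.fs = list(fs)\n", "        for f in self.fs:\n", "            name = _snake(type(f).__name__)\n", "            prev = self.__dict__.get(name)\n", "            # a repeated type name turns the attribute into a list of transforms\n", "            if prev is None: setattr(self, name, f)\n", "            else: setattr(self, name, (prev if isinstance(prev, list) else [prev]) + [f])\n", "\n", "    def __getattr__(self, k):\n", "        # only reached when normal lookup fails: collect `k` from every transform\n", "        if k.startswith('_') or k == 'fs': raise AttributeError(k)\n", "        res = [getattr(f, k) for f in self.fs if getattr(f, k, None) is not None]\n", "        if not res: raise AttributeError(k)\n", "        return res[0] if len(res) == 1 else res\n", "\n", "class DemoTfm: foo = 1\n", "\n", "mp = MiniPipeline([DemoTfm()])\n", "assert isinstance(mp.demo_tfm, DemoTfm) and mp.foo == 1\n", "mp = MiniPipeline([DemoTfm(), DemoTfm()])\n", "assert len(mp.demo_tfm) == 2 and mp.foo == [1, 1]"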
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(pipe.int_float_tfm, int_tfm)\n", "test_eq(pipe.foo, 1)\n", "\n", "pipe = Pipeline([int_tfm, int_tfm])\n", "test_eq(pipe.int_float_tfm[0], int_tfm)\n", "test_eq(pipe.foo, [1,1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check opposite order\n", "pipe = Pipeline([int_tfm,neg_tfm])\n", "t = pipe(start)\n", "test_eq(t, -2)\n", "test_stdout(lambda:pipe.show(t), '-2')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform):\n", "    def encodes(self, x): return int(x)\n", "    def decodes(self, x): return Float(x)\n", "\n", "pipe = Pipeline([neg_tfm, A])\n", "t = pipe(start)\n", "test_eq_type(t, -2)\n", "test_eq_type(pipe.decode(t), Float(start))\n", "test_stdout(lambda:pipe.show(t), '-2.0')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s2 = (1,2)\n", "pipe = Pipeline([neg_tfm, A])\n", "t = pipe(s2)\n", "test_eq_type(t, (-1,-2))\n", "test_eq_type(pipe.decode(t), (Float(1.),Float(2.)))\n", "test_stdout(lambda:pipe.show(t), '-1.0\\n-2.0')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class B(Transform):\n", "    def encodes(self, x): return x+1\n", "    def decodes(self, x): return x-1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from PIL import Image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def f1(x:ArrayImage): return -x\n", "def f2(x): return Image.open(x).resize((128,128))\n", "def f3(x:Image.Image): return ArrayImage(array(x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pipe = Pipeline([f2,f3,f1])\n", "t = pipe(TEST_IMAGE)\n", "test_eq(type(t), ArrayImage)\n", "test_eq(t, -array(f3(f2(TEST_IMAGE))))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pipe = Pipeline([f2,f3])\n", "t = pipe(TEST_IMAGE)\n", "ax = pipe.show(t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_fig_exists(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Check filtering is properly applied\n", "add1 = B()\n", "add1.split_idx = 1\n", "pipe = Pipeline([neg_tfm, A(), add1])\n", "test_eq(pipe(start), -2)\n", "pipe.split_idx=1\n", "test_eq(pipe(start), -1)\n", "pipe.split_idx=0\n", "test_eq(pipe(start), -2)\n", "for t in [None, 0, 1]:\n", "    pipe.split_idx=t\n", "    test_eq(pipe.decode(pipe(start)), start)\n", "    test_stdout(lambda: pipe.show(pipe(start)), \"-2.0\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def neg(x): return -x\n", "test_eq(type(mk_transform(neg)), Transform)\n", "test_eq(type(mk_transform(math.sqrt)), Transform)\n", "test_eq(type(mk_transform(lambda a:a*2)), Transform)\n", "test_eq(type(mk_transform(Pipeline([neg]))), Pipeline)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#TODO: method examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(Pipeline.__call__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(Pipeline.decode)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(Pipeline.setup)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "During setup, the `Pipeline` starts with no transforms and adds them back one at a time, so that when each transform's `setup` runs, it sees the items as processed by the transforms before it, but not by those after it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Tested with TfmdList" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_foundation.ipynb.\n", "Converted 02_utils.ipynb.\n", "Converted 03_dispatch.ipynb.\n", "Converted 04_transform.ipynb.\n", "Converted index.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }