{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#default_exp transform"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"from fastcore.imports import *\n",
"from fastcore.foundation import *\n",
"from fastcore.utils import *\n",
"from fastcore.dispatch import *\n",
"import inspect"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nbdev.showdoc import *\n",
"from fastcore.test import *\n",
"from fastcore.nb_imports import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transforms\n",
"\n",
"> Definition of `Transform` and `Pipeline`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The classes here provide functionality for creating a composition of *partially reversible functions*. By \"partially reversible\" we mean that a transform can be `decode`d, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).\n",
"\n",
"Classes are also provided for composing transforms and mapping them over collections. `Pipeline` is a transform which composes several `Transform`s, knowing how to decode them or show an encoded item."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transform -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"_tfm_methods = 'encodes','decodes','setups'\n",
"\n",
"class _TfmDict(dict):\n",
" def __setitem__(self,k,v):\n",
" if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n",
" if k not in self: super().__setitem__(k,TypeDispatch())\n",
" self[k].add(v)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"class _TfmMeta(type):\n",
" def __new__(cls, name, bases, dict):\n",
" res = super().__new__(cls, name, bases, dict)\n",
" for nm in _tfm_methods:\n",
" base_td = [getattr(b,nm,None) for b in bases]\n",
" if nm in res.__dict__: getattr(res,nm).bases = base_td\n",
" else: setattr(res, nm, TypeDispatch(bases=base_td))\n",
" res.__signature__ = inspect.signature(res.__init__)\n",
" return res\n",
"\n",
" def __call__(cls, *args, **kwargs):\n",
" f = args[0] if args else None\n",
" n = getattr(f,'__name__',None)\n",
" if callable(f) and n in _tfm_methods:\n",
" getattr(cls,n).add(f)\n",
" return f\n",
" return super().__call__(*args, **kwargs)\n",
"\n",
" @classmethod\n",
" def __prepare__(cls, name, bases): return _TfmDict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"def _get_name(o):\n",
" if hasattr(o,'__qualname__'): return o.__qualname__\n",
" if hasattr(o,'__name__'): return o.__name__\n",
" return o.__class__.__name__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"def _is_tuple(o): return isinstance(o, tuple) and not hasattr(o, '_fields')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"class Transform(metaclass=_TfmMeta):\n",
" \"Delegates (`__call__`,`decode`,`setup`) to (`encodes`,`decodes`,`setups`) if `split_idx` matches\"\n",
" split_idx,init_enc,order,train_setup = None,None,0,None\n",
" def __init__(self, enc=None, dec=None, split_idx=None, order=None):\n",
" self.split_idx = ifnone(split_idx, self.split_idx)\n",
" if order is not None: self.order=order\n",
" self.init_enc = enc or dec\n",
" if not self.init_enc: return\n",
"\n",
" self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()\n",
" if enc:\n",
" self.encodes.add(enc)\n",
" self.order = getattr(enc,'order',self.order)\n",
" if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())\n",
" self._name = _get_name(enc)\n",
" if dec: self.decodes.add(dec)\n",
"\n",
" @property\n",
" def name(self): return getattr(self, '_name', _get_name(self))\n",
" def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)\n",
" def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs)\n",
" def __repr__(self): return f'{self.name}:\\nencodes: {self.encodes}decodes: {self.decodes}'\n",
"\n",
" def setup(self, items=None, train_setup=False):\n",
" train_setup = train_setup if self.train_setup is None else self.train_setup\n",
" return self.setups(getattr(items, 'train', items) if train_setup else items)\n",
"\n",
" def _call(self, fn, x, split_idx=None, **kwargs):\n",
" if split_idx!=self.split_idx and self.split_idx is not None: return x\n",
" return self._do_call(getattr(self, fn), x, **kwargs)\n",
"\n",
" def _do_call(self, f, x, **kwargs):\n",
" if not _is_tuple(x):\n",
" if f is None: return x\n",
" ret = f.returns(x) if hasattr(f,'returns') else None\n",
" return retain_type(f(x, **kwargs), x, ret)\n",
" res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)\n",
" return retain_type(res, x)\n",
"\n",
"add_docs(Transform, decode=\"Delegate to `decodes` to undo transform\", setup=\"Delegate to `setups` to set up transform\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [
{
"data": {
"text/markdown": [
"
class
Transform
[source]Transform
(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`order`**=*`None`*)\n",
"\n",
"Delegates (`__call__`,`decode`,`setup`) to (encodes
,decodes
,setups
) if `split_idx` matches"
],
"text/plain": [
"decodes
method. This is mainly used to turn something like a category which is encoded as a number back into a label understandable by humans for showing purposes. Like the regular call method, the `decode` method that is used to decode will be applied over each element of a tuple separately.\n",
"- **Type propagation** - Whenever possible a transform tries to return data of the same type it received. Mainly used to maintain semantics of things like `ArrayImage` which is a thin wrapper of pytorch's `Tensor`. You can opt out of this behavior by adding `->None` return type annotation.\n",
"- **Preprocessing** - The `setup` method can be used to perform any one-time calculations to be later used by the transform, for example generating a vocabulary to encode categorical data.\n",
"- **Filtering based on the dataset type** - By setting the `split_idx` flag you can make the transform be used only in a specific `DataSource` subset like in training, but not validation.\n",
"- **Ordering** - You can set the `order` attribute which the `Pipeline` uses when it needs to merge two lists of transforms.\n",
"- **Appending new behavior with decorators** - You can easily extend an existing `Transform` by creating `encodes` or `decodes` methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you wish them patched into. This can be used by the fastai library users to add their own behavior, or multiple modules contributing to the same transform."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining a `Transform`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are a few ways to create a transform with different ratios of simplicity to flexibility.\n",
"- **Extending the `Transform` class** - Use inheritance to implement the methods you want.\n",
"- **Passing methods to the constructor** - Instantiate the `Transform` class and pass your functions as `enc` and `dec` arguments.\n",
"- **@Transform decorator** - Turn any function into a `Transform` by just adding a decorator - very straightforward if all you need is a single `encodes` implementation.\n",
"- **Passing a function to fastai APIs** - Same as above, but when passing a function to other transform aware classes like `Pipeline` or `TfmdDS` you don't even need a decorator. Your function will get converted to a `Transform` automatically."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A simple way to create a `Transform` is to pass a function to the constructor. In the below example, we pass an anonymous function that does integer division by 2:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f = Transform(lambda o:o//2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you call this transform, it will apply the transformation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq_type(f(2), 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another way to define a Transform is to extend the `Transform` class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class A(Transform): pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, to enable your transform to do something, you have to define an `encodes` method. Note that we can use the class name as a decorator to add this method to the original class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@A\n",
"def encodes(self, x): return x+1\n",
"\n",
"f1 = A()\n",
"test_eq(f1(1), 2) # f1(1) is the same as f1.encode(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to adding an `encodes` method, we can also add a `decodes` method. This enables you to call the `decode` method (without an s). For more information about the purpose of `decodes`, see the discussion about Reversibility in [the above section](#The-main-Transform-features)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like with `encodes`, you can add a `decodes` method to the original class by using the class name as a decorator:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class B(A): pass\n",
"\n",
"@B\n",
"def decodes(self, x): return x-1\n",
"\n",
"f2 = B()\n",
"test_eq(f2.decode(2), 1)\n",
"\n",
"test_eq(f2(1), 2) # uses A's encode method from the parent class"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do not define an `encodes` or `decodes` method, the original value will be returned:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class _Tst(Transform): pass \n",
"\n",
"f3 = _Tst() # no encodes or decodes method have been defined\n",
"test_eq_type(f3.decode(2.0), 2.0)\n",
"test_eq_type(f3(2), 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Defining Transforms With A Decorator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Transform` can be used as a decorator to turn a function into a `Transform`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@Transform\n",
"def f(x): return x//2\n",
"test_eq_type(f(2), 1)\n",
"test_eq_type(f.decode(2.0), 2.0)\n",
"\n",
"@Transform\n",
"def f(x): return x*2\n",
"test_eq_type(f(2), 4)\n",
"test_eq_type(f.decode(2.0), 2.0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Typed Dispatch and Transforms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also apply different transformations depending on the type of the input passed by using `TypeDispatch`. `TypeDispatch` automatically works with `Transform` when using type hints:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class A(Transform): pass\n",
"\n",
"@A\n",
"def encodes(self, x:int): return x//2\n",
"\n",
"@A\n",
"def encodes(self, x:float): return x+1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we pass in an `int`, this calls the first encodes method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f = A()\n",
"test_eq_type(f(3), 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we pass in a `float`, this calls the second encodes method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq_type(f(2.), 3.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we pass in a type that is not specified in `encodes`, the original value is returned:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(f('a'), 'a')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the type annotation is a tuple, then any type in the tuple will match:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class MyClass(int): pass\n",
"\n",
"class A(Transform):\n",
" def encodes(self, x:(MyClass,float)): return x/2\n",
" def encodes(self, x:(str,list)): return str(x)+'_1'\n",
"\n",
"f = A()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The below two examples match the first encodes, with a type of `MyClass` and `float`, respectively:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(f(MyClass(2)), 1.) # input is of type MyClass \n",
"test_eq(f(6.0), 3.0) # input is of type float"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next two examples match the second `encodes` method, with a type of `str` and `list`, respectively:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(f('a'), 'a_1') # input is of type str\n",
"test_eq(f(['a','b','c']), \"['a', 'b', 'c']_1\") # input is of type list"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Casting Types With Transform"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Without any intervention it is easy for operations to change types in Python. For example, `FloatSubclass` (defined below) becomes a `float` after performing multiplication:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class FloatSubclass(float): pass\n",
"test_eq_type(FloatSubclass(3.0) * 2, 6.0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This behavior is often not desirable when performing transformations on data. Therefore, `Transform` will attempt to cast the output to be of the same type as the input by default. In the below example, the output will be cast to a `FloatSubclass` type to match the type of the input:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@Transform\n",
"def f(x): return x*2\n",
"\n",
"test_eq_type(f(FloatSubclass(3.0)), FloatSubclass(6.0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can optionally turn off casting by annotating the transform function with a return type of `None`: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@Transform\n",
"def f(x)-> None: return x*2 # Same transform as above, but with a -> None annotation\n",
"\n",
"test_eq_type(f(FloatSubclass(3.0)), 6.0) # Casting is turned off because of -> None annotation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, `Transform` will only cast output back to the input type when the input is a subclass of the output. In the below example, the input is of type `FloatSubclass` which is not a subclass of the output which is of type `str`. Therefore, the output doesn't get cast back to `FloatSubclass` and stays as type `str`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@Transform\n",
"def f(x): return str(x)\n",
" \n",
"test_eq_type(f(Float(2.)), '2.0')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like `encodes`, the `decodes` method will cast outputs to match the input type in the same way. In the below example, the output of `decodes` remains of type `MySubclass`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class MySubclass(int): pass\n",
"\n",
"def enc(x): return MySubclass(x+1)\n",
"def dec(x): return x-1\n",
"\n",
"\n",
"f = Transform(enc,dec)\n",
"t = f(1) # t is of type MySubclass\n",
"test_eq_type(f.decode(t), MySubclass(1)) # the output of decode is cast to MySubclass to match the input type."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Apply Transforms On Subsets With `split_idx`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can apply transformations to subsets of data by specifying a `split_idx` property. If a transform has a `split_idx` then it's only applied if the `split_idx` param matches. In the below example, we set `split_idx` equal to `1`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def enc(x): return x+1\n",
"def dec(x): return x-1\n",
"f = Transform(enc,dec)\n",
"f.split_idx = 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The transformations are applied when a matching `split_idx` parameter is passed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(f(1, split_idx=1),2)\n",
"test_eq(f.decode(2, split_idx=1),1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On the other hand, transformations are ignored when the `split_idx` parameter does not match:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(f(1, split_idx=0), 1)\n",
"test_eq(f.decode(2, split_idx=0), 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Transforms on Lists"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transform operates on lists as a whole, **not element-wise**:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class A(Transform):\n",
" def encodes(self, x): return dict(x)\n",
" def decodes(self, x): return list(x.items())\n",
" \n",
"f = A()\n",
"_inp = [(1,2), (3,4)]\n",
"t = f(_inp)\n",
"\n",
"test_eq(t, dict(_inp))\n",
"test_eq(f.decodes(t), _inp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#hide\n",
"f.split_idx = 1\n",
"test_eq(f(_inp, split_idx=1), dict(_inp))\n",
"test_eq(f(_inp, split_idx=0), _inp)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want a transform to operate on a list elementwise, you must implement this appropriately in the `encodes` and `decodes` methods:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AL(Transform): pass\n",
"\n",
"@AL\n",
"def encodes(self, x): return [x_+1 for x_ in x]\n",
"\n",
"@AL\n",
"def decodes(self, x): return [x_-1 for x_ in x]\n",
"\n",
"f = AL()\n",
"t = f([1,2])\n",
"\n",
"test_eq(t, [2,3])\n",
"test_eq(f.decode(t), [1,2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Transforms on Tuples\n",
"\n",
"Unlike lists, `Transform` operates on tuples element-wise."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def neg_int(x): return -x\n",
"f = Transform(neg_int)\n",
"\n",
"test_eq(f((1,2,3)), (-1,-2,-3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transforms will also apply `TypeDispatch` element-wise on tuples when an input type annotation is specified. In the below example, the values `1.0` and `3.0` are ignored because they are of type `float`, not `int`: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def neg_int(x:int): return -x\n",
"f = Transform(neg_int)\n",
"\n",
"test_eq(f((1.0, 2, 3.0)), (1.0, -2, 3.0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#hide\n",
"test_eq(f((1,)), (-1,))\n",
"test_eq(f((1.,)), (1.,))\n",
"test_eq(f.decode((1,2)), (1,2))\n",
"test_eq(f.input_types, int)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another example of how `Transform` can use `TypeDispatch` with tuples is shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class B(Transform): pass\n",
"\n",
"@B\n",
"def encodes(self, x:int): return x+1\n",
"\n",
"@B\n",
"def encodes(self, x:str): return x+'hello'\n",
"\n",
"@B\n",
"def encodes(self, x)->None: return str(x)+'!'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the input is not an `int` or `str`, the third `encodes` method will apply:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"b = B()\n",
"test_eq(b([1]), '[1]!') \n",
"test_eq(b([1.0]), '[1.0]!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, if the input is a tuple, then the appropriate method will apply according to the type of each element in the tuple:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(b(('1',)), ('1hello',))\n",
"test_eq(b((1,2)), (2,3))\n",
"test_eq(b(('a',1.0)), ('ahello','1.0!'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#hide\n",
"@B\n",
"def decodes(self, x:int): return x-1\n",
"\n",
"test_eq(b.decode((2,)), (1,))\n",
"test_eq(b.decode(('2',)), ('2',))\n",
"assert pickle.loads(pickle.dumps(b))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dispatching over tuples works recursively, by the way:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class B(Transform):\n",
" def encodes(self, x:int): return x+1\n",
" def encodes(self, x:str): return x+'_hello'\n",
" def decodes(self, x:int): return x-1\n",
" def decodes(self, x:str): return x.replace('_hello', '')\n",
"\n",
"f = B()\n",
"start = (1.,(2,'3'))\n",
"t = f(start)\n",
"test_eq_type(t, (1.,(3,'3_hello')))\n",
"test_eq(f.decode(t), start)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dispatching also works with abstract numeric type classes from the `numbers` module, like `numbers.Integral`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@Transform\n",
"def f(x:numbers.Integral): return x+1\n",
"\n",
"t = f((1,'1',1))\n",
"test_eq(t, (2, '1', 2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"class InplaceTransform(Transform):\n",
" \"A `Transform` that modifies in-place and just returns whatever it's passed\"\n",
" def _call(self, fn, x, split_idx=None, **kwargs):\n",
" super()._call(fn,x,split_idx,**kwargs)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#hide\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class A(InplaceTransform): pass\n",
"\n",
"@A\n",
"def encodes(self, x:pd.Series): x.fillna(10, inplace=True)\n",
" \n",
"f = A()\n",
"\n",
"test_eq_type(f(pd.Series([1,2,None])),pd.Series([1,2,10],dtype=np.float64)) #fillna fills with floats."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# export\n",
"class DisplayedTransform(Transform):\n",
" \"A transform with a `__repr__` that shows its attrs\"\n",
"\n",
" @property\n",
" def name(self): return f\"{super().name} -- {getattr(self,'__stored_args__',{})}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transforms normally are represented by just their class name and a list of encodes and decodes implementations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"A:\n",
"encodes: (object,object) -> noop\n",
"decodes: (object,object) -> noop"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class A(Transform): encodes,decodes = noop,noop\n",
"f = A()\n",
"f"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A `DisplayedTransform` will in addition show the contents of all attributes saved with `store_attr` (read from `self.__stored_args__`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"A -- {'a': 1, 'b': 2}:\n",
"encodes: (object,object) -> noop\n",
"decodes: "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class A(DisplayedTransform):\n",
" encodes = noop\n",
" def __init__(self, a, b=2):\n",
" super().__init__()\n",
" store_attr()\n",
" \n",
"A(a=1,b=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### ItemTransform -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"class ItemTransform(Transform):\n",
" \"A transform that always take tuples as items\"\n",
" _retain = True\n",
" def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)\n",
" def decode(self, x, **kwargs): return self._call1(x, 'decode', **kwargs)\n",
" def _call1(self, x, name, **kwargs):\n",
" if not _is_tuple(x): return getattr(super(), name)(x, **kwargs)\n",
" y = getattr(super(), name)(list(x), **kwargs)\n",
" if not self._retain: return y\n",
" if is_listy(y) and not isinstance(y, tuple): y = tuple(y)\n",
" return retain_type(y, x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
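"`ItemTransform` is the class to use to opt out of the default behavior of `Transform`, which applies the transform to each element of a tuple separately: an `ItemTransform` receives the tuple as a whole item instead."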
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AIT(ItemTransform): \n",
" def encodes(self, xy): x,y=xy; return (x+y,y)\n",
" def decodes(self, xy): x,y=xy; return (x-y,y)\n",
" \n",
"f = AIT()\n",
"test_eq(f((1,2)), (3,2))\n",
"test_eq(f.decode((3,2)), (1,2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you pass a special tuple subclass, the usual retain type behavior of `Transform` will keep it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class _T(tuple): pass\n",
"x = _T((1,2))\n",
"test_eq_type(f(x), _T((3,2)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#hide\n",
"f.split_idx = 0\n",
"test_eq_type(f((1,2)), (1,2))\n",
"test_eq_type(f((1,2), split_idx=0), (3,2))\n",
"test_eq_type(f.decode((1,2)), (1,2))\n",
"test_eq_type(f.decode((3,2), split_idx=0), (1,2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#hide\n",
"class Get(ItemTransform):\n",
" _retain = False\n",
" def encodes(self, x): return x[0]\n",
" \n",
"g = Get()\n",
"test_eq(g([1,2,3]), 1)\n",
"test_eq(g(L(1,2,3)), 1)\n",
"test_eq(g(np.array([1,2,3])), 1)\n",
"test_eq_type(g((['a'], ['b', 'c'])), ['a'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"#hide\n",
"class A(ItemTransform): \n",
" def encodes(self, x): return _T((x,x))\n",
" def decodes(self, x): return _T(x)\n",
" \n",
"f = A()\n",
"test_eq(type(f.decode((1,1))), _T)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Func -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"def get_func(t, name, *args, **kwargs):\n",
" \"Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined\"\n",
" f = getattr(t, name, noop)\n",
" return f if not (args or kwargs) else partial(f, *args, **kwargs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This works for any kind of `t` supporting `getattr`, so a class or a module."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(get_func(operator, 'neg', 2)(), -2)\n",
"test_eq(get_func(operator.neg, '__call__')(2), -2)\n",
"test_eq(get_func(list, 'foobar')([2]), [2])\n",
"a = [2,1]\n",
"get_func(list, 'sort')(a)\n",
"test_eq(a, [1,2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transforms are built with multiple-dispatch: a given function can have several methods depending on the type of the object received. This is done directly with the `TypeDispatch` module and type-annotation in `Transform`, but you can also use the following class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"class Func():\n",
" \"Basic wrapper around a `name` with `args` and `kwargs` to call on a given type\"\n",
" def __init__(self, name, *args, **kwargs): self.name,self.args,self.kwargs = name,args,kwargs\n",
" def __repr__(self): return f'sig: {self.name}({self.args}, {self.kwargs})'\n",
" def _get(self, t): return get_func(t, self.name, *self.args, **self.kwargs)\n",
" def __call__(self,t): return mapped(self._get, t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can call the `Func` object on any module name or type, even a list of types. It will return the corresponding function (with a default to `noop` if nothing is found) or list of functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(Func('sqrt')(math), math.sqrt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"#export\n",
"class _Sig():\n",
" def __getattr__(self,k):\n",
" def _inner(*args, **kwargs): return Func(k, *args, **kwargs)\n",
" return _inner\n",
"\n",
"Sig = _Sig()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [
{
"data": {
"text/markdown": [
"Sig
[source]Sig
(**\\*`args`**, **\\*\\*`kwargs`**)\n",
"\n"
],
"text/plain": [
"