{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_001b import *\n", "import sys, PIL, matplotlib.pyplot as plt, itertools, math, random, collections, torch\n", "import scipy.stats, scipy.special\n", "\n", "from enum import Enum, IntEnum\n", "from torch import tensor, Tensor, FloatTensor, LongTensor, ByteTensor, DoubleTensor, HalfTensor, ShortTensor\n", "from operator import itemgetter, attrgetter\n", "from numpy import cos, sin, tan, tanh, log, exp\n", "from dataclasses import field\n", "from functools import reduce\n", "from collections import defaultdict, abc, namedtuple, Iterable\n", "from typing import Tuple, Hashable, Mapping, Dict\n", "\n", "import mimetypes, abc, functools\n", "from abc import abstractmethod, abstractproperty" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import show_doc as sd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# CIFAR subset data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we want to view our data to check if everything is how we expect it to be." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DATA_PATH = Path('data')\n", "PATH = DATA_PATH/'cifar10_dog_air'\n", "TRAIN_PATH = PATH/'train'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dog_fn = list((TRAIN_PATH/'dog').iterdir())[0]\n", "dog_image = PIL.Image.open(dog_fn)\n", "dog_image.resize((256,256))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "air_fn = list((TRAIN_PATH/'airplane').iterdir())[1]\n", "air_image = PIL.Image.open(air_fn)\n", "air_image.resize((256,256))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple Dataset/Dataloader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will build a Dataset class for our image files. A Dataset class needs to have two functions: `__len__` and `__getitem__`. Our `ImageDataset` class additionally gets image files from their respective directories and transforms them to tensors." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def image2np(image:Tensor)->np.ndarray:\n", " \"Convert from torch style `image` to numpy/matplotlib style\"\n", " res = image.cpu().permute(1,2,0).numpy()\n", " return res[...,0] if res.shape[2]==1 else res\n", "\n", "def show_image(img:Tensor, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, \n", " title:Optional[str]=None, cmap:str='binary', alpha:Optional[float]=None)->plt.Axes:\n", " \"Plot tensor `img` using matplotlib axis `ax`. 
`figsize` sizes the figure; `hide_axis` and `title` control the axes; `cmap` and `alpha` pass to `ax.imshow`\"\n", " if ax is None: fig,ax = plt.subplots(figsize=figsize)\n", " ax.imshow(image2np(img), cmap=cmap, alpha=alpha)\n", " if hide_axis: ax.axis('off')\n", " if title: ax.set_title(title)\n", " return ax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Image():\n", " def __init__(self, px): self.px = px\n", " def show(self, ax=None, **kwargs): return show_image(self.px, ax=ax, **kwargs)\n", " @property\n", " def data(self): return self.px" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "FilePathList = Collection[Path]\n", "TensorImage = Tensor\n", "NPImage = np.ndarray\n", "\n", "def find_classes(folder:Path)->FilePathList:\n", " \"Return class subdirectories in imagenet style train `folder`\"\n", " classes = [d for d in folder.iterdir()\n", " if d.is_dir() and not d.name.startswith('.')]\n", " assert(len(classes)>0)\n", " return sorted(classes, key=lambda d: d.name)\n", "\n", "image_extensions = set(k for k,v in mimetypes.types_map.items() if v.startswith('image/'))\n", "\n", "def get_image_files(c:Path, check_ext:bool=True)->FilePathList:\n", " \"Return list of files in `c` that are images. `check_ext` will filter to `image_extensions`.\"\n", " return [o for o in list(c.iterdir())\n", " if not (o.name.startswith('.') or o.is_dir()\n", " or (check_ext and o.suffix not in image_extensions))]\n", "\n", "def pil2tensor(image:NPImage)->TensorImage:\n", " \"Convert PIL style `image` array to torch style image tensor\"\n", " arr = torch.ByteTensor(torch.ByteStorage.from_buffer(image.tobytes()))\n", " arr = arr.view(image.size[1], image.size[0], -1)\n", " return arr.permute(2,0,1)\n", "\n", "PathOrStr = Union[Path,str]\n", "def open_image(fn:PathOrStr):\n", " \"Return `Image` object created from image in file `fn`\"\n", " x = PIL.Image.open(fn).convert('RGB')\n", " return Image(pil2tensor(x).float().div_(255))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "NPArrayableList = Collection[Union[np.ndarray, list]]\n", "NPArrayMask = np.ndarray\n", "SplitArrayList = List[Tuple[np.ndarray,np.ndarray]]\n", "\n", "def arrays_split(mask:NPArrayMask, *arrs:NPArrayableList)->SplitArrayList:\n", " \"Given `arrs` is [a,b,...] and a boolean `mask`, return [(a[mask],a[~mask]),(b[mask],b[~mask]),...]\"\n", " mask = array(mask)\n", " return list(zip(*[(a[mask],a[~mask]) for a in map(np.array, arrs)]))\n", "\n", "def random_split(valid_pct:float, *arrs:NPArrayableList)->SplitArrayList:\n", " \"Randomly split arrays with `arrays_split` using `valid_pct` ratio. 
Good for creating validation set.\"\n", " is_train = np.random.uniform(size=(len(arrs[0]),)) > valid_pct\n", " return arrays_split(is_train, *arrs)\n", "\n", "class DatasetBase(Dataset):\n", " \"Base class for all fastai datasets\"\n", " def __len__(self): return len(self.x)\n", " @property\n", " def c(self): \n", " \"Number of classes expressed by dataset y variable\"\n", " return self.y.shape[-1] if len(self.y.shape)>1 else 1\n", " def __repr__(self): return f'{type(self).__name__} of len {len(self)}'\n", "\n", "class LabelDataset(DatasetBase):\n", " \"Base class for fastai datasets that do classification\"\n", " @property\n", " def c(self): \n", " \"Number of classes expressed by dataset y variable\"\n", " return len(self.classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "ImgLabel = str\n", "ImgLabels = Collection[ImgLabel]\n", "Classes = Collection[Any]\n", "\n", "class ImageDataset(LabelDataset):\n", " \"Dataset for folders of images in style {folder}/{class}/{images}\"\n", " def __init__(self, fns:FilePathList, labels:ImgLabels, classes:Optional[Classes]=None):\n", " self.classes = ifnone(classes, list(set(labels)))\n", " self.class2idx = {v:k for k,v in enumerate(self.classes)}\n", " self.x = np.array(fns)\n", " self.y = np.array([self.class2idx[o] for o in labels], dtype=np.int64)\n", " \n", " def __getitem__(self,i): return open_image(self.x[i]),self.y[i]\n", " \n", " @staticmethod\n", " def _folder_files(folder:Path, label:ImgLabel, check_ext=True)->Tuple[FilePathList,ImgLabels]:\n", " \"From `folder` return image files and labels. The labels are all `label`. `check_ext` means only image files\"\n", " fnames = get_image_files(folder, check_ext=check_ext)\n", " return fnames,[label]*len(fnames)\n", " \n", " @classmethod\n", " def from_single_folder(cls, folder:PathOrStr, classes:Classes, check_ext=True):\n", " \"Typically used for test set. Label all images in `folder` with `classes[0]`\"\n", " fns,labels = cls._folder_files(folder, classes[0], check_ext=check_ext)\n", " return cls(fns, labels, classes=classes)\n", "\n", " @classmethod\n", " def from_folder(cls, folder:Path, classes:Optional[Classes]=None, \n", " valid_pct:float=0., check_ext:bool=True) -> Union['ImageDataset', List['ImageDataset']]:\n", " \"Dataset of `classes` labeled images in `folder`. Optionally split off a `valid_pct` validation set.\"\n", " if classes is None: classes = [d.name for d in find_classes(folder)]\n", " \n", " fns,labels = [],[]\n", " for cl in classes:\n", " f,l = cls._folder_files(folder/cl, cl, check_ext=check_ext)\n", " fns+=f; labels+=l\n", " \n", " if valid_pct==0.: return cls(fns, labels, classes=classes)\n", " return [cls(*a, classes=classes) for a in random_split(valid_pct, fns, labels)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sd(ImageDataset.from_folder, arg_comments={\"folder\": \"Folder containing subfolders, one for each class\"})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "# Data augmentation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are going to augment our data to increase the size of our training set with artificial images. These new images are basically \"free\" data that we can use in our training to help our model generalize better (reduce overfitting)."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lighting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will start by changing the **brightness** and **contrast** of our images." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Method" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Brightness**\n", "\n", "Brightness refers to where does our image stand on the dark-light spectrum. Brightness is applied by adding a positive constant to each of the image's channels. This works because each of the channels in an image goes from 0 (darkest) to 255 (brightest) in a dark-light continum. (0, 0, 0) is black (total abscence of light) and (255, 255, 255) is white (pure light). You can check how this works by experimenting by yourself [here](https://www.w3schools.com/colors/colors_rgb.asp).\n", "\n", "_Parameters_\n", "\n", "1. **Change** How much brightness do we want to add to (or take from) the image.\n", "\n", " Domain: Real numbers\n", " \n", "**Contrast**\n", "\n", "Contrast refers to how sharp a distinction there is between brighter and darker sections of our image. To increase contrast we need darker pixels to be darker and lighter pixels to be lighter. In other words, we would like channels with a value smaller than 128 to decrease and channels with a value of greater than 128 to increase.\n", "\n", "_Parameters_\n", "\n", "1. **Scale** How much contrast do we want to add to (or remove from) the image.\n", "\n", " Domain: [0, +inf]\n", " \n", "***On logit and sigmoid***\n", "\n", "Notice that for both transformations we first apply the logit to our tensor, then apply the transformation and finally take the sigmoid. This is important for two reasons. \n", "\n", "First, we don't want to overflow our tensor values. In other words, we need our final tensor values to be between [0,1]. Imagine, for instance, a tensor value at 0.99. We want to increase its brightness, but we can’t go over 1.0. By doing logit first, which first moves our space to -inf to +inf, this works fine. The same applies to contrast if we have a scale S > 1 (might make some of our tensor values greater than one).\n", "\n", "Second, when we apply contrast, we need to affect the dispersion of values around the middle value. Say we want to increase contrast. Then we need the bright values (>0.5) to get brighter and dark values (<0.5) to get darker. We must first transform our tensor values so our values which were originally <0.5 are now negative and our values which were originally >0.5 are now positive. This way, when we multiply by a constant, the dispersion around 0 will increase. The logit function does exactly this and allows us to increase or decrease dispersion around a mid value." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def logit(x:Tensor)->Tensor: return -(1/x-1).log()\n", "def logit_(x:Tensor)->Tensor: return (x.reciprocal_().sub_(1)).log_().neg_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def contrast(x:Tensor, scale:float)->Tensor: return x.mul_(scale)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "FlowField = Tensor\n", "LogitTensorImage = TensorImage\n", "AffineMatrix = Tensor\n", "KWArgs = Dict[str,Any]\n", "ArgStar = Collection[Any]\n", "TensorImageSize = Tuple[int,int,int]\n", "\n", "LightingFunc = Callable[[LogitTensorImage, ArgStar, KWArgs], LogitTensorImage]\n", "PixelFunc = Callable[[TensorImage, ArgStar, KWArgs], TensorImage]\n", "CoordFunc = Callable[[FlowField, TensorImageSize, ArgStar, KWArgs], LogitTensorImage]\n", "AffineFunc = Callable[[KWArgs], AffineMatrix]\n", "\n", "\n", "class ItemBase():\n", " \"All transformable dataset items use this type\"\n", " @property\n", " @abstractmethod\n", " def device(self): pass\n", " @property\n", " @abstractmethod\n", " def data(self): pass\n", "\n", "class ImageBase(ItemBase):\n", " \"Img based `Dataset` items derive from this. Subclass to handle lighting, pixel, etc\"\n", " def lighting(self, func:LightingFunc, *args, **kwargs)->'ImageBase': return self\n", " def pixel(self, func:PixelFunc, *args, **kwargs)->'ImageBase': return self\n", " def coord(self, func:CoordFunc, *args, **kwargs)->'ImageBase': return self\n", " def affine(self, func:AffineFunc, *args, **kwargs)->'ImageBase': return self\n", "\n", " def set_sample(self, **kwargs)->'ImageBase':\n", " \"Set parameters that control how we `grid_sample` the image after transforms are applied\"\n", " self.sample_kwargs = kwargs\n", " return self\n", " \n", " def clone(self)->'ImageBase': \n", " \"Clones this item and its `data`\"\n", " return self.__class__(self.data.clone())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Image(ImageBase):\n", " \"Supports appying transforms to image data\"\n", " def __init__(self, px)->'Image':\n", " \"create from raw tensor image data `px`\"\n", " self._px = px\n", " self._logit_px=None\n", " self._flow=None\n", " self._affine_mat=None\n", " self.sample_kwargs = {}\n", "\n", " @property\n", " def shape(self)->Tuple[int,int,int]: \n", " \"Returns (ch, h, w) for this image\"\n", " return self._px.shape\n", " @property\n", " def size(self)->Tuple[int,int]: \n", " \"Returns (h, w) for this image\"\n", " return self.shape[-2:]\n", " @property\n", " def device(self)->torch.device: return self._px.device\n", " \n", " def __repr__(self): return f'{self.__class__.__name__} ({self.shape})'\n", "\n", " def refresh(self)->None:\n", " \"Applies any logit or affine transfers that have been \"\n", " if self._logit_px is not None:\n", " self._px = self._logit_px.sigmoid_()\n", " self._logit_px = None\n", " if self._affine_mat is not None or self._flow is not None:\n", " self._px = grid_sample(self._px, self.flow, **self.sample_kwargs)\n", " self.sample_kwargs = {}\n", " self._flow = None\n", " return self\n", "\n", " @property\n", " def px(self)->TensorImage:\n", " \"Get the tensor pixel buffer\"\n", " self.refresh()\n", " return self._px\n", " @px.setter\n", " def px(self,v:TensorImage)->None: \n", " 
\"Set the pixel buffer to `v`\"\n", " self._px=v\n", "\n", " @property\n", " def flow(self)->FlowField:\n", " \"Access the flow-field grid after applying queued affine transforms\"\n", " if self._flow is None:\n", " self._flow = affine_grid(self.shape)\n", " if self._affine_mat is not None:\n", " self._flow = affine_mult(self._flow,self._affine_mat)\n", " self._affine_mat = None\n", " return self._flow\n", " \n", " @flow.setter\n", " def flow(self,v:FlowField): self._flow=v\n", "\n", " def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image':\n", " \"Equivalent to `image = sigmoid(func(logit(image)))`\"\n", " self.logit_px = func(self.logit_px, *args, **kwargs)\n", " return self\n", "\n", " def pixel(self, func:PixelFunc, *args, **kwargs)->'Image':\n", " \"Equivalent to `image.px = func(image.px)`\"\n", " self.px = func(self.px, *args, **kwargs)\n", " return self\n", "\n", " def coord(self, func:CoordFunc, *args, **kwargs)->'Image':\n", " \"Equivalent to `image.flow = func(image.flow, image.size)`\" \n", " self.flow = func(self.flow, self.shape, *args, **kwargs)\n", " return self\n", "\n", " def affine(self, func:AffineFunc, *args, **kwargs)->'Image':\n", " \"Equivalent to `image.affine_mat = image.affine_mat @ func()`\" \n", " m = tensor(func(*args, **kwargs)).to(self.device)\n", " self.affine_mat = self.affine_mat @ m\n", " return self\n", "\n", " def resize(self, size:Union[int,TensorImageSize])->'Image':\n", " \"Resize the image to `size`, size can be a single int\"\n", " assert self._flow is None\n", " if isinstance(size, int): size=(self.shape[0], size, size)\n", " self.flow = affine_grid(size)\n", " return self\n", "\n", " @property\n", " def affine_mat(self)->AffineMatrix:\n", " \"Get the affine matrix that will be applied by `refresh`\"\n", " if self._affine_mat is None:\n", " self._affine_mat = torch.eye(3).to(self.device)\n", " return self._affine_mat\n", " @affine_mat.setter\n", " def affine_mat(self,v)->None: self._affine_mat=v\n", "\n", " @property\n", " def logit_px(self)->LogitTensorImage:\n", " \"Get logit(image.px)\"\n", " if self._logit_px is None: self._logit_px = logit_(self.px)\n", " return self._logit_px\n", " @logit_px.setter\n", " def logit_px(self,v:LogitTensorImage)->None: self._logit_px=v\n", " \n", " def show(self, ax:plt.Axes=None, **kwargs:Any)->None: \n", " \"Plots the image into `ax`\"\n", " show_image(self.px, ax=ax, **kwargs)\n", " \n", " @property\n", " def data(self)->TensorImage: \n", " \"Returns this images pixels as a tensor\"\n", " return self.px" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = ImageDataset.from_folder(PATH/'train')\n", "valid_ds = ImageDataset.from_folder(PATH/'test')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = lambda: train_ds[1][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img = x()\n", "img.logit_px = contrast(img.logit_px, 0.5)\n", "img.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x().lighting(contrast, 0.5).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transform class" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Transform():\n", " _wrap=None\n", " def __init__(self, func): self.func=func\n", " def __call__(self, x, *args, **kwargs):\n", " if self._wrap: return getattr(x, self._wrap)(self.func, 
*args, **kwargs)\n", " else: return self.func(x, *args, **kwargs)\n", " \n", "class TfmLighting(Transform): _wrap='lighting'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@TfmLighting\n", "def brightness(x, change): return x.add_(scipy.special.logit(change))\n", "@TfmLighting\n", "def contrast(x, scale): return x.mul_(scale)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "\n", "x().show(axes[0])\n", "contrast(x(), 1.0).show(axes[1])\n", "contrast(x(), 0.5).show(axes[2])\n", "contrast(x(), 2.0).show(axes[3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "\n", "x().show(axes[0])\n", "brightness(x(), 0.8).show(axes[1])\n", "brightness(x(), 0.5).show(axes[2])\n", "brightness(x(), 0.2).show(axes[3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def brightness_contrast(x, scale_contrast, change_brightness):\n", " return brightness(contrast(x, scale=scale_contrast), change=change_brightness)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "\n", "brightness_contrast(x(), 0.75, 0.7).show(axes[0])\n", "brightness_contrast(x(), 2.0, 0.3).show(axes[1])\n", "brightness_contrast(x(), 2.0, 0.7).show(axes[2])\n", "brightness_contrast(x(), 0.75, 0.3).show(axes[3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random lighting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we will make our previous transforms random since we are interested in automating the pipeline. We will achieve this by making our parameters stochastic with a specific distribution. \n", "\n", "We will use a uniform distribution for brightness change since its domain is the real numbers and the impact varies linearly with the change. For contrast change we use [log_uniform](https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php) for two reasons. First, contrast scale has a domain of (0, inf). Second, the impact of the scale on the transformation is non-linear (i.e. 0.5 is as extreme as 2.0, 0.2 is as extreme as 5). The log_uniform function is appropriate because it has the same domain and correctly represents the non-linearity of the transform, P(0.5) = P(2)."
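] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick empirical sketch of that symmetry, in plain torch (written out by hand here, before the `log_uniform` helper is defined below): drawing uniformly in log space makes shrinking and boosting contrast equally likely." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: ~half the draws from a log-uniform on (0.5,2.0) fall below 1\n", "lo,hi = math.log(0.5),math.log(2.0)\n", "samples = torch.empty(10000).uniform_(lo,hi).exp()\n", "(samples < 1.0).float().mean()"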
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "FloatOrTensor = Union[float,Tensor]\n", "BoolOrTensor = Union[bool,Tensor]\n", "def uniform(low:Number, high:Number, size:List[int]=None)->FloatOrTensor:\n", " \"Draw 1 or shape=`size` random floats from uniform dist: min=`low`, max=`high`\"\n", " return random.uniform(low,high) if size is None else torch.FloatTensor(*listify(size)).uniform_(low,high)\n", "\n", "def log_uniform(low, high, size=None)->FloatOrTensor:\n", " \"Draw 1 or shape=`size` random floats from uniform dist: min=log(`low`), max=log(`high`)\"\n", " res = uniform(log(low), log(high), size)\n", " return exp(res) if size is None else res.exp_()\n", "\n", "def rand_bool(p:float, size=None)->BoolOrTensor: \n", " \"Draw 1 or shape=`size` random booleans (True occuring probability p)\"\n", " return uniform(0,1,size)None:\n", " \"Create a transform for `func` and assign it an priority `order`, attach to Image class\"\n", " if order is not None: self.order=order\n", " self.func=func\n", " functools.update_wrapper(self, self.func)\n", " self.func.__annotations__['return'] = Image\n", " self.params = copy(func.__annotations__)\n", " self.def_args = get_default_args(func)\n", " setattr(Image, func.__name__,\n", " lambda x, *args, **kwargs: self.calc(x, *args, **kwargs))\n", " \n", " def __call__(self, *args:Any, p:float=1., is_random:bool=True, **kwargs:Any)->Image:\n", " \"Calc now if `args` passed; else create a transform called prob `p` if `random`\"\n", " if args: return self.calc(*args, **kwargs)\n", " else: return RandTransform(self, kwargs=kwargs, is_random=is_random, p=p)\n", " \n", " def calc(self, x:Image, *args:Any, **kwargs:Any)->Image:\n", " \"Apply this transform to image `x`, wrapping it if necessary\"\n", " if self._wrap: return getattr(x, self._wrap)(self.func, *args, **kwargs)\n", " else: return self.func(x, *args, **kwargs)\n", "\n", " @property\n", " def name(self)->str: return self.__class__.__name__\n", " \n", " def __repr__(self)->str: return f'{self.name} ({self.func.__name__})'\n", "\n", "class TfmLighting(Transform): order,_wrap = 8,'lighting'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class RandTransform():\n", " \"Wraps `Transform` to add randomized execution\"\n", " tfm:Transform\n", " kwargs:dict\n", " p:int=1.0\n", " resolved:dict = field(default_factory=dict)\n", " do_run:bool = True\n", " is_random:bool = True\n", " def __post_init__(self): functools.update_wrapper(self, self.tfm)\n", " \n", " def resolve(self)->None:\n", " \"Bind any random variables needed tfm calc\"\n", " if not self.is_random:\n", " self.resolved = {**self.tfm.def_args, **self.kwargs}\n", " return\n", "\n", " self.resolved = {}\n", " # for each param passed to tfm...\n", " for k,v in self.kwargs.items():\n", " # ...if it's annotated, call that fn...\n", " if k in self.tfm.params:\n", " rand_func = self.tfm.params[k]\n", " self.resolved[k] = rand_func(*listify(v))\n", " # ...otherwise use the value directly\n", " else: self.resolved[k] = v\n", " # use defaults for any args not filled in yet\n", " for k,v in self.tfm.def_args.items():\n", " if k not in self.resolved: self.resolved[k]=v\n", " # anything left over must be callable without params\n", " for k,v in self.tfm.params.items():\n", " if k not in self.resolved and k!='return': self.resolved[k]=v()\n", "\n", " self.do_run = rand_bool(self.p)\n", "\n", " @property\n", " def 
order(self)->int: return self.tfm.order\n", "\n", " def __call__(self, x:Image, *args, **kwargs)->Image:\n", " \"Randomly execute our tfm on `x`\"\n", " return self.tfm(x, *args, **{**self.resolved, **kwargs}) if self.do_run else x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@TfmLighting\n", "def brightness(x, change:uniform): \n", " \"`change` brightness of image `x`\"\n", " return x.add_(scipy.special.logit(change))\n", "\n", "@TfmLighting\n", "def contrast(x, scale:log_uniform): \n", " \"`scale` contrast of image `x`\"\n", " return x.mul_(scale)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x().contrast(scale=2).show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x().contrast(scale=2).brightness(0.8).show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = contrast(scale=(0.3,3))\n", "tfm.resolve()\n", "tfm,tfm.resolved,tfm.do_run" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# all the same\n", "tfm.resolve()\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: tfm(x()).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = contrast(scale=(0.3,3))\n", "\n", "# different\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes:\n", " tfm.resolve()\n", " tfm(x()).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = contrast(scale=2, is_random=False)\n", "tfm.resolve()\n", "tfm(x()).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Composition" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We want to compose the transform functions so that we can apply them all at once. We will feed a list of transforms to our pipeline and have it apply all of them.\n", "\n", "The easiest way to apply a function to our transforms before they are called is with a decorator. You can find more about decorators [here](https://www.thecodeship.com/patterns/guide-to-python-function-decorators/)."
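] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a toy illustration of that pattern (the `logged` decorator below is hypothetical, not part of the library): a decorator replaces the function it wraps with a new object. `@TfmLighting` above works the same way, replacing the decorated function with a `Transform`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical decorator, for illustration only\n", "def logged(func):\n", "    def wrapper(*args, **kwargs):\n", "        print(f'calling {func.__name__}')\n", "        return func(*args, **kwargs)\n", "    return wrapper\n", "\n", "@logged\n", "def double(x): return x*2\n", "double(4)"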
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "TfmList=Union[Transform, Collection[Transform]]\n", "def resolve_tfms(tfms:TfmList):\n", " \"Resolve every tfm in `tfms`\"\n", " for f in listify(tfms): f.resolve()\n", "\n", "def apply_tfms(tfms:TfmList, x:Image, do_resolve:bool=True):\n", " \"Apply all the `tfms` to `x`, if `do_resolve` refresh all the random args\"\n", " if not tfms: return x\n", " tfms = listify(tfms)\n", " if do_resolve: resolve_tfms(tfms)\n", " x = x.clone()\n", " for tfm in tfms: x = tfm(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = train_ds[1][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [contrast(scale=(0.3,3.0), p=0.9),\n", " brightness(change=(0.35,0.65), p=0.9)]\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: apply_tfms(tfms,x).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(2,4, figsize=(12,6))\n", "for i in range(4):\n", " apply_tfms(tfms,x).show(axes[0,i])\n", " apply_tfms(tfms,x,do_resolve=False).show(axes[1,i])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "apply_tfms([],x).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DatasetTfm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class DatasetTfm(Dataset):\n", " \"A `Dataset` that applies a list of transforms to every item drawn\"\n", " def __init__(self, ds:Dataset, tfms:TfmList=None, **kwargs:Any):\n", " \"this dataset will apply `tfms` to `ds`\"\n", " self.ds,self.tfms,self.kwargs = ds,tfms,kwargs\n", " \n", " def __len__(self)->int: return len(self.ds)\n", " \n", " def __getitem__(self,idx:int)->Tuple[Image,Any]:\n", " \"returns tfms(x),y\"\n", " x,y = self.ds[idx]\n", " return apply_tfms(self.tfms, x, **self.kwargs), y\n", " \n", " def __getattr__(self,k): \n", " \"passthrough access to wrapped dataset attributes\"\n", " return getattr(self.ds, k)\n", "\n", "import nb_001b\n", "nb_001b.DatasetTfm = DatasetTfm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=64" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "ItemsList = Collection[Union[Tensor,ItemBase,'ItemsList',float,int]]\n", "def to_data(b:ItemsList):\n", " \"Recursively maps lists of items to their wrapped data\"\n", " if is_listy(b): return [to_data(o) for o in b]\n", " return b.data if isinstance(b,ItemBase) else b\n", "\n", "def data_collate(batch:ItemsList)->Tensor:\n", " \"Convert `batch` items to tensor data\"\n", " return torch.utils.data.dataloader.default_collate(to_data(batch))\n", "\n", "@dataclass\n", "class DeviceDataLoader():\n", " \"DataLoader that ensures items in each batch are tensor on specified device\"\n", " dl: DataLoader\n", " device: torch.device\n", " def __post_init__(self)->None: self.dl.collate_fn=data_collate\n", "\n", " def __len__(self)->int: return len(self.dl)\n", " def __getattr__(self,k:str)->Any: return getattr(self.dl, k)\n", " def proc_batch(self,b:ItemsList)->Tensor: return to_device(b, self.device)\n", "\n", " def __iter__(self):\n", " self.gen = map(self.proc_batch, self.dl)\n", " return iter(self.gen)\n", "\n", " @classmethod\n", " def create(cls, *args, 
device=default_device, **kwargs)->'DeviceDataLoader':\n", " \"Creates a `DataLoader` and makes sure its data is always on `device`\"\n", " return cls(DataLoader(*args, **kwargs), device=device)\n", " \n", "nb_001b.DeviceDataLoader = DeviceDataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch.create(train_ds, valid_ds, bs=bs, num_workers=4)\n", "len(data.train_dl), len(data.valid_dl), data.train_dl.dataset.c" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def show_image_batch(dl:DataLoader, classes:Collection[str], \n", " rows:Optional[int]=None, figsize:Tuple[int,int]=(12,15))->None:\n", " \"Show a batch of images from `dl` titled according to `classes`\"\n", " x,y = next(iter(dl))\n", " if rows is None: rows = int(math.sqrt(len(x)))\n", " show_images(x[:rows*rows],y[:rows*rows],rows, classes, figsize)\n", "\n", "def show_images(x:Collection[Image],y:int,rows:int, classes:Collection[str], figsize:Tuple[int,int]=(9,9))->None:\n", " \"Plot images (`x[i]`) from `x` titled according to `classes[y[i]]`\"\n", " fig, axs = plt.subplots(rows,rows,figsize=figsize)\n", " for i, ax in enumerate(axs.flatten()):\n", " show_image(x[i], ax)\n", " ax.set_title(classes[y[i]])\n", " plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_image_batch(data.train_dl, train_ds.classes, 6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_image_batch(data.train_dl, train_ds.classes, 6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Affine" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will now add affine transforms that operate on the coordinates instead of pixels like the lighting transforms we just saw. An [affine transformation](https://en.wikipedia.org/wiki/Affine_transformation) is a function \"(...) between affine spaces which preserves points, straight lines and planes.\" " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Details" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our implementation first creates a grid of coordinates for the original image. The grid is normalized to a [-1, 1] range with (-1, -1) representing the top left corner, (1, 1) the bottom right corner and (0, 0) the center. Next, we build an affine matrix representing our desired transform and we multiply it by our original grid coordinates. The result will be a set of x, y coordinates which reference where in the input image each pixel of the output image will be mapped from. It has a size of h \\* w \\* 2 since it needs two coordinates for each of the h \\* w pixels of the output image. \n", "\n", "This is clearest if we see it graphically. We will build an affine matrix of the following form:\n", "\n", "`[[a, b, e],\n", " [c, d, f]]`\n", "\n", "\n", "with which we will transform each pair of x, y coordinates in our original grid into our transformation grid:\n", "\n", "\n", "`[[a, b], [[x], [[e], [[x'],\n", " [c, d]] x [y]] + [f]] = [y']]` \n", "\n", "So after the transform we will get a new grid with which to map our input image into our output image. 
This will be our **map of exactly where in the input image each pixel of the output image is sourced from**.\n", "\n", "**Enter problems**\n", "\n", "Affine transforms face two problems that must be solved independently:\n", "1. **The interpolation problem**: The result of our transformation gives us float coordinates, and we need to decide, for each (i,j), how to assign these coordinates to pixels in the input image.\n", "2. **The missing pixel problem**: The result of our transformation may have coordinates which exceed the [-1, 1] range of our original grid and thus fall outside of our original grid.\n", "\n", "**Solutions to problems**\n", "\n", "1. **The interpolation problem**: We will perform a [bilinear interpolation](https://en.wikipedia.org/wiki/Bilinear_interpolation). This takes an average of the values of the pixels corresponding to the four points in the grid surrounding the result of our transformation, with weights depending on how close we are to each of those points. \n", "2. **The missing pixel problem**: For these values we need padding, and we face a few options:\n", "\n", " 1. Adding zeros on the side (so the pixels that fall out will be black)\n", " 2. Replacing them by the value at the border\n", " 3. Mirroring the content of the picture on the other side (reflect padding).\n", " \n", " \n", "### Transformation Method\n", "\n", "**Zoom**\n", "\n", "Zoom changes the focus of the image according to a scale. If a scale of >1 is applied, grid pixels will be mapped to coordinates that are more central than the pixel's coordinates (closer to 0,0) while if a scale of <1 is applied, grid pixels will be mapped to more peripheral coordinates (closer to the borders) in the input image.\n", "\n", "We can also translate our transform to zoom into a non-central area of the image. For this we use $col_c$ which displaces the x axis and $row_c$ which displaces the y axis.\n", "\n", "_Parameters_\n", "\n", "1. **Scale** How much do we want to zoom in or out of our image.\n", "\n", " Domain: Positive real numbers\n", " \n", "2. **Col_pct** How much do we want to displace our zoom along the x axis.\n", "\n", " Domain: Real numbers between 0 and 1\n", " \n", " \n", "3. **Row_pct** How much do we want to displace our zoom along the y axis.\n", "\n", " Domain: Real numbers between 0 and 1\n", " \n", "\n", "Affine matrix\n", "\n", "`[[1/scale, 0, col_c],\n", " [0, 1/scale, row_c]]`\n", "\n", "\n", "**Rotate**\n", "\n", "Rotate turns the image around its center by a given angle theta. The rotation is counterclockwise if theta is positive and clockwise if theta is negative. If you are curious about the derivation of the rotation matrix you can find it [here](https://matthew-brett.github.io/teaching/rotation_2d.html).\n", "\n", "_Parameters_\n", "\n", "1. **Degrees** By what angle do we want to rotate our image.\n", "\n", " Domain: Real numbers\n", " \n", "Affine matrix\n", "\n", "`[[cos(theta), -sin(theta), 0],\n", " [sin(theta), cos(theta), 0]]`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deterministic affine" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def grid_sample_nearest(input:TensorImage, coords:FlowField, padding_mode:str='zeros')->TensorImage:\n", " \"Grab pixels in `coords` from `input`. 
Sample with nearest neighbor mode, pad with zeros by default\"\n", " if padding_mode=='border': coords.clamp_(-1,1)\n", " bs,ch,h,w = input.size()\n", " sz = tensor([w,h]).float()[None,None]\n", " coords.add_(1).mul_(sz/2)\n", " coords = coords[0].round_().long()\n", " if padding_mode=='zeros':\n", " mask = (coords[...,0] < 0) + (coords[...,1] < 0) + (coords[...,0] >= w) + (coords[...,1] >= h)\n", " mask.clamp_(0,1)\n", " coords[...,0].clamp_(0,w-1)\n", " coords[...,1].clamp_(0,h-1)\n", " result = input[...,coords[...,1],coords[...,0]]\n", " if padding_mode=='zeros': result[...,mask] = result[...,mask].zero_()\n", " return result" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def grid_sample(x:TensorImage, coords:FlowField, mode:str='bilinear', padding_mode:str='reflect')->TensorImage:\n", " \"Grab pixels in `coords` from `x` sampling by `mode`. Pad is 'reflect' or 'zeros'.\"\n", " if padding_mode=='reflect': padding_mode='reflection'\n", " #if mode=='nearest': return grid_sample_nearest(x[None], coords, padding_mode)[0]\n", " return F.grid_sample(x[None], coords, mode=mode, padding_mode=padding_mode)[0]\n", "\n", "def affine_grid(size:TensorImageSize)->FlowField:\n", " \"Generate a flow-field of normalized [-1,1] coords for an image of `size`\"\n", " size = ((1,)+size)\n", " N, C, H, W = size\n", " grid = FloatTensor(N, H, W, 2)\n", " linear_points = torch.linspace(-1, 1, W) if W > 1 else tensor([-1])\n", " grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(grid[:, :, :, 0])\n", " linear_points = torch.linspace(-1, 1, H) if H > 1 else tensor([-1])\n", " grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(grid[:, :, :, 1])\n", " return grid\n", "\n", "def affine_mult(c:FlowField, m:AffineMatrix)->FlowField:\n", " \"Apply affine matrix `m` to each coord in flow-field `c`\"\n", " if m is None: return c\n", " size = c.size()\n", " c = c.view(-1,2)\n", " c = torch.addmm(m[:2,2], c, m[:2,:2].t()) \n", " return c.view(size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def rotate(degrees):\n", " angle = degrees * math.pi / 180\n", " return [[cos(angle), -sin(angle), 0.],\n", " [sin(angle), cos(angle), 0.],\n", " [0. , 0. 
, 1.]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def xi(): return train_ds[1][0]\n", "x = xi().data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c = affine_grid(x.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = rotate(30)\n", "m = x.new_tensor(m)\n", "m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c[0,...,0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c[0,...,1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c = affine_mult(c,m)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c[0,...,0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c[0,...,1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img2 = grid_sample(x, c, padding_mode='zeros')\n", "show_image(img2);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xi().affine(rotate, 30).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Affine transform" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TfmAffine(Transform): \n", " \"Wraps affine tfm funcs\"\n", " order,_wrap = 5,'affine'\n", "class TfmPixel(Transform): \n", " \"Wraps pixel tfm funcs\"\n", " order,_wrap = 10,'pixel'\n", "\n", "@TfmAffine\n", "def rotate(degrees:uniform):\n", " \"Affine func that rotates the image\"\n", " angle = degrees * math.pi / 180\n", " return [[cos(angle), -sin(angle), 0.],\n", " [sin(angle), cos(angle), 0.],\n", " [0. , 0. , 1.]]\n", "\n", "def get_zoom_mat(sw:float, sh:float, c:float, r:float)->AffineMatrix:\n", " \"`sw`,`sh` scale width,height - `c`,`r` focus col,row\"\n", " return [[sw, 0, c],\n", " [0, sh, r],\n", " [0, 0, 1.]]\n", "\n", "@TfmAffine\n", "def zoom(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5):\n", " \"Zoom image by `scale`. `row_pct`,`col_pct` select focal point of zoom\"\n", " s = 1-1/scale\n", " col_c = s * (2*col_pct - 1)\n", " row_c = s * (2*row_pct - 1)\n", " return get_zoom_mat(1/scale, 1/scale, col_c, row_c)\n", "\n", "@TfmAffine\n", "def squish(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5):\n", " \"Squish image by `scale`. 
`row_pct`,`col_pct` select focal point of squish\"\n", " if scale <= 1: \n", " col_c = (1-scale) * (2*col_pct - 1)\n", " return get_zoom_mat(scale, 1, col_c, 0.)\n", " else: \n", " row_c = (1-1/scale) * (2*row_pct - 1)\n", " return get_zoom_mat(1, 1/scale, 0., row_c)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rotate(xi(), 30).show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "zoom(xi(), 0.6).show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "zoom(xi(), 0.6).set_sample(padding_mode='zeros').show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "zoom(xi(), 2, 0.2, 0.2).show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "scales = [0.75,0.9,1.1,1.33]\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for i, ax in enumerate(axes): squish(xi(), scales[i]).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img2 = rotate(xi(), 30).refresh()\n", "img2 = zoom(img2, 1.6)\n", "_,axes=plt.subplots(1,3,figsize=(9,3))\n", "xi().show(axes[0])\n", "img2.show(axes[1])\n", "zoom(rotate(xi(), 30), 1.6).show(axes[2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xi().show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xi().resize(48).show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img2 = zoom(xi().resize(48), 1.6, 0.8, 0.2)\n", "rotate(img2, 30).show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img2 = zoom(xi().resize(24), 1.6, 0.8, 0.2)\n", "rotate(img2, 30).show(hide_axis=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img2 = zoom(xi().resize(48), 1.6, 0.8, 0.2)\n", "rotate(img2, 30).set_sample(mode='nearest').show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random affine" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we did with the Lighting transform, we now want to build randomness into our pipeline so we can automate the transform process. \n", "\n", "We will use a uniform distribution for both our transforms since their impact is linear and their domain is the real numbers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Apply all transforms**\n", "\n", "We will make all transforms do as few calculations as possible.\n", "\n", "We do only one affine transformation by multiplying all the affine matrices of the transforms, then we apply to the coords any non-affine transformation we might want (jitter, elastic distortion). Next, we crop the coordinates we want to keep and, by doing it before the interpolation, we don't need to compute pixel values that won't be used afterwards. Finally, we perform the interpolation and we apply all the transforms that operate pixelwise (brightness, contrast)."
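] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A small sketch of the matrix bookkeeping behind this: chaining `rotate` and `zoom` reduces to one 3x3 matrix product, which is what `Image.affine` accumulates in `affine_mat` before `refresh` interpolates once." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: composing two affine matrices = one matrix product\n", "theta = 30 * math.pi / 180\n", "rot_m = tensor([[math.cos(theta), -math.sin(theta), 0.],\n", "                [math.sin(theta),  math.cos(theta), 0.],\n", "                [0., 0., 1.]])\n", "zoom_m = tensor(get_zoom_mat(1/1.6, 1/1.6, 0., 0.))\n", "rot_m @ zoom_m  # applied once to the flow-field"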
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = rotate(degrees=(-45,45.), p=0.75); tfm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm.resolve(); tfm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = xi()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: apply_tfms(tfm, x).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [rotate(degrees=(-45,45.), p=0.75),\n", " zoom(scale=(0.5,2.0), p=0.75)]\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: apply_tfms(tfms,x).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def apply_tfms(tfms:TfmList, x:TensorImage, do_resolve:bool=True, \n", " xtra:Optional[Dict[Transform,dict]]=None, size:TensorImageSize=None, **kwargs:Any)->TensorImage:\n", " \"Apply `tfms` to x, resize to `size`. `do_resolve` rebind random params. `xtra` custom args for a tfm\"\n", " if not (tfms or size): return x\n", " if not xtra: xtra={}\n", " tfms = sorted(listify(tfms), key=lambda o: o.tfm.order)\n", " if do_resolve: resolve_tfms(tfms)\n", " x = x.clone()\n", " if kwargs: x.set_sample(**kwargs)\n", " if size: x.resize(size)\n", " for tfm in tfms:\n", " if tfm.tfm in xtra: x = tfm(x, **xtra[tfm.tfm])\n", " else: x = tfm(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [rotate(degrees=(-45,45.), p=0.75),\n", " zoom(scale=(1.0,2.0), row_pct=(0,1.), col_pct=(0,1.))]\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: apply_tfms(tfms,x, padding_mode='zeros', size=64).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [squish(scale=(0.5,2), row_pct=(0,1.), col_pct=(0,1.))]\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: apply_tfms(tfms,x).show(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Coord and pixel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Jitter / flip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last two transforms we will use are **jitter** and **flip**. \n", "\n", "**Jitter**\n", "\n", "Jitter is a transform which adds a random value to each of the pixels to make them somewhat different than the original ones. In our implementation we first get a random number between (-1, 1) and we multiply it by a constant M which scales it.\n", "\n", "_Parameters_\n", "\n", "1. **Magnitude** How much random noise do we want to add to each of the pixels in our image.\n", "\n", " Domain: Real numbers between 0 and 1.\n", " \n", "**Flip**\n", "\n", "Flip is a transform that reflects the image on a given axis.\n", "\n", "_Parameters_\n", "\n", "1. **P** Probability of applying the transformation to an input.\n", "\n", " Domain: Real numbers between 0 and 1." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TfmCoord(Transform): order,_wrap = 4,'coord'\n", "\n", "@TfmCoord\n", "def jitter(c, size, magnitude:uniform):\n", " return c.add_((torch.rand_like(c)-0.5)*magnitude*2)\n", "\n", "@TfmPixel\n", "def flip_lr(x): return x.flip(2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = jitter(magnitude=(0,0.1))\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes:\n", " tfm.resolve()\n", " tfm(xi()).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = flip_lr(p=0.5)\n", "\n", "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes:\n", " tfm.resolve()\n", " tfm(xi()).show(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Crop/pad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Crop**\n", "\n", "Crop is a transform that cuts a series of pixels from an image. It does this by removing rows and columns from the input image.\n", "\n", "_Parameters_\n", "\n", "1. **Size** What is the target size of each side in pixels. If only one number *s* is specified, image is made square with dimensions *s* \\* *s*.\n", "\n", " Domain: Positive integers.\n", " \n", "2. **Row_pct** Determines where to cut our image vertically on the bottom and top (which rows are left out). If <0.5, more rows will be cut in the top than in the bottom and viceversa (varies linearly).\n", "\n", " Domain: Real numbers between 0 and 1.\n", " \n", "3. **Col_pct** Determines where to cut our image horizontally on the left and right (which columns are left out). If <0.5, more rows will be cut in the left than in the right and viceversa (varies linearly).\n", "\n", " Domain: Real numbers between 0 and 1.\n", " \n", "Our three parameters are related with the following equations:\n", "\n", "1. output_rows = [**row_pct***(input_rows-**size**):**size**+**row_pct***(input_rows-**size**)]\n", "\n", "2. output_cols = [**col_pct***(input_cols-**size**):**size**+**col_pct***(input_cols-**size**)]\n", "\n", "**Pad**\n", "\n", "\n", "Pads each of the four borders of our image with a certain amount of pixels. Can pad with reflection (reflects border pixels to fill new pixels) or zero (adds black pixels). \n", "\n", "_Parameters_\n", "\n", "1. **Padding** Amount of pixels to add to each border. [More details](https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad)\n", "\n", " Domain: Positive integers.\n", " \n", "2. **Mode** How to fill new pixels. For more detail see the Pytorch subfunctions for padding.\n", "\n", " Domain: \n", " - Reflect (default): reflects opposite pixels to fill new pixels. [More details](https://pytorch.org/docs/stable/nn.html#torch.nn.ReflectionPad2d)\n", " - Constant: adds pixels with specified value (default is 0, black pixels) [More details](https://pytorch.org/docs/stable/nn.html#torch.nn.ConstantPad2d)\n", " - Replicate: replicates border row or column pixels to fill new pixels [More details](https://pytorch.org/docs/stable/nn.html#torch.nn.ReplicationPad2d)\n", " \n", " \n", "***On using padding and crop***\n", "\n", "A nice way to use these two functions is to combine them into one transform. We can add padding to the image and then crop some of it out. This way, we can create a new image to augment our training set without losing image information by cropping. 
Furthermore, this can be done in several ways (modifying the amount and type of padding and the crop style) so it gives us great flexibility to add images to our training set. You can find an example of this in the code below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "[(o.__name__,o.order) for o in\n", " sorted((Transform,TfmAffine,TfmCoord,TfmLighting,TfmPixel),key=attrgetter('order'))]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@partial(TfmPixel, order=-10)\n", "def pad(x, padding, mode='reflect'):\n", " \"Pad `x` with `padding` pixels. `mode` fills in space ('reflect','constant',etc)\"\n", " return F.pad(x[None], (padding,)*4, mode=mode)[0]\n", "\n", "@TfmPixel\n", "def crop(x, size, row_pct:uniform=0.5, col_pct:uniform=0.5):\n", " \"Crop `x` to `size` pixels. `row_pct`,`col_pct` select focal point of crop\"\n", " size = listify(size,2)\n", " rows,cols = size\n", " row = int((x.size(1)-rows+1) * row_pct)\n", " col = int((x.size(2)-cols+1) * col_pct)\n", " return x[:, row:row+rows, col:col+cols].contiguous()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pad(xi(), 4, 'constant').show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "crop(pad(xi(), 4, 'constant'), 32, 0.25, 0.75).show(hide_axis=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "crop(pad(xi(), 4), 32, 0.25, 0.75).show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combine" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [flip_lr(p=0.5),\n", " pad(padding=4, mode='constant'),\n", " crop(size=32, row_pct=(0,1.), col_pct=(0,1.))]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: apply_tfms(tfms, x).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [\n", " flip_lr(p=0.5),\n", " contrast(scale=(0.5,2.0)),\n", " brightness(change=(0.3,0.7)),\n", " rotate(degrees=(-45,45.), p=0.5),\n", " zoom(scale=(0.5,1.2), p=0.8)\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(1,4, figsize=(12,3))\n", "for ax in axes: apply_tfms(tfms, x).show(ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(2,4, figsize=(12,6))\n", "\n", "for i in range(4):\n", " apply_tfms(tfms, x, padding_mode='zeros', size=48).show(axes[0][i], hide_axis=False)\n", " apply_tfms(tfms, x, mode='nearest', do_resolve=False).show(axes[1][i], hide_axis=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RandomResizedCrop (Torchvision version)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def compute_zs_mat(sz:TensorImageSize, scale:float, squish:float, \n", " invert:bool, row_pct:float, col_pct:float)->AffineMatrix:\n", " \"Utility routine to compute zoom/squish matrix\"\n", " orig_ratio = math.sqrt(sz[2]/sz[1])\n", " for s,r,i in zip(scale,squish, invert):\n", " s,r = math.sqrt(s),math.sqrt(r)\n", " if s * r <= 1 and s / r <= 1: #Test if we are completely inside the picture\n", " w,h = (s/r, s*r) if i else (s*r,s/r)\n", " w /= 
orig_ratio\n", " h *= orig_ratio\n", " col_c = (1-w) * (2*col_pct - 1)\n", " row_c = (1-h) * (2*row_pct - 1)\n", " return get_zoom_mat(w, h, col_c, row_c)\n", " \n", " #Fallback, hack to emulate a center crop without cropping anything yet.\n", " if orig_ratio > 1: return get_zoom_mat(1/orig_ratio**2, 1, 0, 0.)\n", " else: return get_zoom_mat(1, orig_ratio**2, 0, 0.)\n", "\n", "@TfmCoord\n", "def zoom_squish(c, size, scale:uniform=1.0, squish:uniform=1.0, invert:rand_bool=False, \n", " row_pct:uniform=0.5, col_pct:uniform=0.5):\n", " #This is intended for scale, squish and invert to be of size 10 (or whatever) so that the transform\n", " #can try a few zoom/squishes before falling back to center crop (like torchvision.RandomResizedCrop)\n", " m = compute_zs_mat(size, scale, squish, invert, row_pct, col_pct)\n", " return affine_mult(c, FloatTensor(m))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rrc = zoom_squish(scale=(0.25,1.0,10), squish=(0.5,1.0,10), invert=(0.5,10),\n", " row_pct=(0,1.), col_pct=(0,1.))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_,axes = plt.subplots(2,4, figsize=(12,6))\n", "for i in range(4):\n", " apply_tfms(rrc, x, size=48).show(axes[0][i])\n", " apply_tfms(rrc, x, do_resolve=False, mode='nearest').show(axes[1][i])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }