{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "from fastai.datasets import URLs, untar_data\n", "from pathlib import Path\n", "import pandas as pd, re, PIL, os, mimetypes, csv, itertools\n", "import matplotlib.pyplot as plt\n", "from collections import OrderedDict\n", "from enum import Enum\n", "from warnings import warn\n", "from functools import partial,reduce\n", "from PIL import Image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data source API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Transform` is the base class to define a transform when we want something more complex than a function. Its `_order` is used to sort the transforms before applying them. `setup` is a preparation step that gets the state ready from the `data` (which is a `DataSource`). `__call__` is the main function: it applies the transform to `o` by calling `encodes`, and `decode` does the reverse operation by calling `decodes`. Both return `o` unchanged when the transform has a `filt` that differs from the one passed in.\n", "\n", "**NB: You should only implement `decodes` if your transform needs to be reversed for display purposes.** For instance, we want to reverse the *class to index* operation, but not the *open this image* operation."
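] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is a minimal, self-contained sketch of that contract, not the `Transform` class defined below: the hypothetical `ToyCategorize` computes its state in `setup`, maps a class name to an index when called, and implements `decode` because the *class to index* operation has to be reversed for display." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical toy transform, not part of the API defined in this notebook\n", "class ToyCategorize:\n", "    _order = 1  # would be used to sort it among other transforms\n", "    def setup(self, items):\n", "        # Compute the state once from the data: the sorted list of classes\n", "        self.vocab = sorted(set(items))\n", "        self.o2i = {v: i for i, v in enumerate(self.vocab)}\n", "    def __call__(self, o): return self.o2i[o]  # class -> index\n", "    def decode(self, o): return self.vocab[o]  # index -> class, for display\n", "\n", "toy_tfm = ToyCategorize()\n", "toy_tfm.setup(['dog', 'cat', 'dog', 'bird'])\n", "toy_idx = toy_tfm('dog')\n", "toy_name = toy_tfm.decode(toy_idx)\n", "toy_idx, toy_name, toy_tfm.vocab"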
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Transform():\n", " _order,filt = 0,None\n", "\n", " def __init__(self, encodes=None, decodes=None, filt=None, order=None):\n", " self.filt = filt\n", " if encodes is not None: self.encodes = encodes\n", " if decodes is not None: self.decodes = decodes\n", " if order is not None: self._order=order\n", "\n", " @classmethod\n", " def create(cls, f, filt=None): return f if hasattr(f,'setup') or isinstance(f,Transform) else cls(f, filt=filt)\n", " \n", " def __call__(self, o, filt=None, **kwargs): \n", " if self.filt is not None and self.filt!=filt: return o\n", " return self.encodes(o, **kwargs)\n", "\n", " def decode(self, o, filt=None, **kwargs): \n", " if self.filt is not None and self.filt!=filt: return o\n", " return self.decodes(o, **kwargs)\n", " \n", " def __repr__(self): return str(self.encodes) if self.__class__==Transform else str(self.__class__)\n", " def decodes(self, o, *args, **kwargs): return o\n", " # `Transforms.add` calls `setup`; subclasses put their own logic in `setups`\n", " def setup(self, items=None): return self.setups(items)\n", " def setups(self, items=None): pass\n", "\n", "def order_sorted(funcs, order='_order'):\n", " \"Listify `funcs` and sort with `order`.\"\n", " key = lambda o: getattr(o, order, 0)\n", " return sorted(listify(funcs), key=key)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def opt_call(f, fname, *args, **kwargs): return getattr(f,fname,noop)(*args, **kwargs)\n", "\n", "class Transforms():\n", " def __init__(self, tfms, order='_order'):\n", " self.order,self.tfms = order,[]\n", " self._tfms = [Transform.create(t) for t in listify(tfms)]\n", "\n", " def __call__(self, x, **kwargs): return self._apply(x, **kwargs)\n", " def decode(self, x, **kwargs): return self._apply(x, rev=True, fname='decode', **kwargs)\n", " def _apply(self, x, rev=False, fname='__call__', **kwargs):\n", " tfms = reversed(self.tfms) if rev else self.tfms\n", " for f in tfms: x = opt_call(f, fname, x, **kwargs)\n", " return x\n", "\n", " def __repr__(self): return 
str(self.tfms)\n", " def delete(self, idx): del(self.tfms[idx])\n", " def remove(self, tfm): self.tfms.remove(tfm)\n", "\n", " def setup(self, items=None): self.add(self._tfms, items)\n", " def add(self, tfm, items):\n", " # We only add one at a time so that each setup has access to the correct tfm subset\n", " for t in order_sorted(tfm):\n", " self.tfms.append(t)\n", " opt_call(t, 'setup', items)\n", " \n", " def __getattr__(self, k):\n", " for t in reversed(self.tfms):\n", " a = getattr(t, k, None)\n", " if a is not None: return a\n", " raise AttributeError(k)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`DataSource` is the base class of the data block API and is defined from `items`, `tfms` and `filts`. It can represent multiple datasets (train, valid, or more) that are contained in the `items`: each element of `filts` is a boolean mask or a collection of ints that says which `items` are in which dataset.\n", "\n", "When accessing an element, `tfms` are applied to it with optional `tfm_kwargs` passed along. Those kwargs are filtered so that each transform only gets the ones it accepts. At its simplest, a transform is just a function (open an image, resize it, one-hot encode a category, etc.), but it can be more complex (see the `Transform` class above). Some transforms need a setup (for instance, the transform that changes a category to its index needs to compute all the classes) and some can be reversed for display purposes (if you change a category to an index, you still want to display the category name later on, and if you normalize your image, you need to undo that to display it). `DataSource` calls the `setup` method of each of its transforms (when there is one) at initialization, and it has a `decode` method that reverses the transforms that can be reversed."
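] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough illustration of how `filts` select subsets, the next cell uses two hypothetical helpers, `toy_mask2idxs` and `toy_subset` (not the `mask2idxs` helper or the `DataSource` class used in this notebook), to turn a boolean mask into indices and apply a transform to the items of each subset. The two subsets in the example overlap, which is allowed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical helpers sketching how filts pick subsets of items\n", "def toy_mask2idxs(mask):\n", "    # Boolean mask -> positions of its True entries\n", "    return [i for i, m in enumerate(mask) if m]\n", "\n", "def toy_subset(items, filt, tfm=lambda o: o):\n", "    # A filt is either a boolean mask or a collection of ints\n", "    idxs = toy_mask2idxs(filt) if isinstance(filt[0], bool) else list(filt)\n", "    # The transform is applied to each item of the subset\n", "    return [tfm(items[i]) for i in idxs]\n", "\n", "toy_items = list(range(5))\n", "toy_train = toy_subset(toy_items, [0, 2])\n", "toy_valid = toy_subset(toy_items, [True, False, False, True, True], tfm=lambda x: x * 10)\n", "toy_train, toy_valid"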
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def coll_repr(c, max=1000):\n", " return f'(#{len(c)}) [' + ','.join(itertools.islice(map(str,c), 10)) + ('...'\n", " if len(c)>10 else '') + ']\\n'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class DataSource():\n", " def __init__(self, items, tfms=noop, tfm=None, filts=None):\n", " if filts is None: filts = [range_of(items)]\n", " ft = mask2idxs if isinstance(filts[0][0], bool) else listify\n", " self.filts = listify(ft(filt) for filt in filts)\n", " self.items,self.tfm = listify(items),ifnone(tfm, Transforms(tfms))\n", " self.tfm.setup(self)\n", "\n", " def __len__(self): return len(self.filts)\n", " def len(self, filt=0): return len(self.filts[filt])\n", " def __getitem__(self, i): return FilteredList(self, i)\n", " def decode(self, o, filt=0, **kwargs): return self.tfm.decode(o, filt=filt, **kwargs)\n", " def decoded(self, idx, filt=0): return self.decode(self.get(idx,filt), filt)\n", " def __iter__(self): return (self[i] for i in range_of(self))\n", " \n", " def get(self, idx, filt=0):\n", " if hasattr(idx,'__len__') and getattr(idx,'ndim',1):\n", " # rank>0 collection\n", " if isinstance(idx[0],bool): idx = mask2idxs(idx)\n", " return [self.get(i,filt) for i in idx] # index list\n", " it = self.items[self.filts[filt][idx]]\n", " return self.tfm(it, filt=filt)\n", "\n", " def __eq__(self,b):\n", " if not isinstance(b,DataSource): b = DataSource(b)\n", " return len(b) == len(self) and all(o==p for o,p in zip(self,b))\n", "\n", " def __repr__(self):\n", " res = f'{self.__class__.__name__}\\n'\n", " for i,o in enumerate(self): res += f'{i}: {coll_repr(o)}'\n", " return res\n", " \n", " @property\n", " def train(self): return self[0]\n", " @property\n", " def valid(self): return self[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"#export\n", "class FilteredList:\n", " def __init__(self, dsrc, filt): self.dsrc,self.filt = dsrc,filt\n", " def __getitem__(self,i): return self.dsrc.get(i,self.filt)\n", " def decode(self, o): return self.dsrc.decode(o, self.filt)\n", " def __len__(self): return self.dsrc.len(self.filt)\n", " def __eq__(self,b): return len(b) == len(self) and all(o==p for o,p in zip(self,b))\n", " def __iter__(self): return (self[i] for i in range_of(self))\n", " def __repr__(self): return coll_repr(self)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Tests" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "#Indexing\n", "dsrc = DataSource(range(5))\n", "test_eq(dsrc,[0,1,2,3,4])\n", "test_eq(list(dsrc[0]),[0,1,2,3,4])\n", "test_ne(dsrc,[0,1,2,3,5])\n", "test_eq(dsrc.get(2),2)\n", "test_eq(dsrc.get([1,2]),[1,2])\n", "test_eq(dsrc.get([True,False,False,True,False]),[0,3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "#filts can be indices or boolean masks\n", "dsrc = DataSource(range(5), filts=[[0,2], [1,3,4]])\n", "test_eq(list(dsrc[0]),[0,2])\n", "test_eq(list(dsrc[1]),[1,3,4])\n", "#Subsets don't have to be disjoint\n", "dsrc = DataSource(range(5), filts=[[False,True,True,False,True], [True,False,False,True,True]])\n", "test_eq(list(dsrc[0]),[1,2,4])\n", "test_eq(list(dsrc[1]),[0,3,4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dsrc" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "#Base transform\n", "dsrc = DataSource(range(5), lambda x:x*2)\n", "test_eq(dsrc,[0,2,4,6,8])\n", "test_eq(list(dsrc[0]),[0,2,4,6,8])\n", "test_ne(dsrc,[1,2,4,6,8])\n", "test_eq(dsrc.get(2), 4)\n", "test_eq(dsrc.get([1,2]), [2,4])\n", "test_eq(dsrc.get([True,False,False,True,False]), [0,6])" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "# test\n", "#Different transforms for the two subsets\n", "dsrc = DataSource(range(5), Transform(lambda x: x*2, filt=1), filts=[[1,2],[0,3,4]])\n", "# test_eq(list(dsrc[0]),[1,2])\n", "test_eq(list(dsrc[1]),[0,6,8])\n", "test_eq(dsrc.get(2,1), 8)\n", "test_eq(dsrc.get([1,2], 1), [6,8])\n", "test_eq(dsrc.get([False,True], 0), [2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "def add(x, a=1): return x+a\n", "def multiply(x, a=2): return x*a\n", "def square(x): return x**2\n", "def add_undo(x, a=1): return x-a\n", "def multiply_undo(x, a=2): return x/a\n", "addt = Transform(add, add_undo, order=2)\n", "multt = Transform(multiply, multiply_undo, order=1)\n", "sqrt = Transform(square, order=0)\n", "\n", "#Test _order\n", "tfms = [addt,multt,sqrt]\n", "dsrc = DataSource([0,1,2,3], tfms, filts=[range(4)])\n", "test_eq(dsrc.get(2), ((2**2) * 2) + 1)\n", "\n", "#Test decode\n", "dsrc = DataSource([0,1,2,3], tfms, filts=[[0,1,2,3]])\n", "test_eq(dsrc.decode(9), (9-1)/2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper functions to create blocks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Get image files" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _get_files(p, fs, extensions=None):\n", " p = Path(p)\n", " res = [p/f for f in fs if not f.startswith('.')\n", " and ((not extensions) or f'.{f.split(\".\")[-1].lower()}' in extensions)]\n", " return res\n", "\n", "def get_files(path, extensions=None, recurse=False, include=None):\n", " \"Get all the files in `path` with optional `extensions`.\"\n", " path = Path(path)\n", " extensions = setify(extensions)\n", " extensions = {e.lower() for e in extensions}\n", " if recurse:\n", " res = []\n", " for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)\n", " if include is not None and 
i==0: d[:] = [o for o in d if o in include]\n", " else: d[:] = [o for o in d if not o.startswith('.')]\n", " res += _get_files(p, f, extensions)\n", " else:\n", " f = [o.name for o in os.scandir(path) if o.is_file()]\n", " res = _get_files(path, f, extensions)\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "path = untar_data(URLs.MNIST_TINY)\n", "test_eq(len(get_files(path/'train'/'3')),346)\n", "test_eq(len(get_files(path/'train'/'3', extensions='.png')),346)\n", "test_eq(len(get_files(path/'train'/'3', extensions='.jpg')),0)\n", "test_eq(len(get_files(path/'train', extensions='.png')),0)\n", "test_eq(len(get_files(path/'train', extensions='.png', recurse=True)),709)\n", "test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train'])),709)\n", "test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train', 'test'])),729)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "image_extensions = set(k for k,v in mimetypes.types_map.items() if v.startswith('image/'))\n", "\n", "def get_image_files(path, include=None, **kwargs):\n", " \"Get image files in `path` recursively.\"\n", " return get_files(path, extensions=image_extensions, recurse=True, include=include)\n", "\n", "def image_getter(suf='', **kwargs):\n", " def _inner(o, **kw): return get_image_files(o/suf, **{**kwargs,**kw})\n", " return _inner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "path = untar_data(URLs.MNIST_TINY)\n", "test_eq(len(get_image_files(path)),1428)\n", "test_eq(len(get_image_files(path/'train')),709)\n", "test_eq(len(get_image_files(path, include='train')),709)\n", "test_eq(len(get_image_files(path, include=['train','valid'])),1408)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "path = 
untar_data(URLs.MNIST_TINY)\n", "test_eq(len(image_getter()(path)),1428)\n", "test_eq(len(image_getter('train')(path)),709)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def show_image(im, ax=None, figsize=None, title=None, **kwargs):\n", " \"Show a PIL image or an image tensor on `ax`.\"\n", " if ax is None: _,ax = plt.subplots(figsize=figsize)\n", " if isinstance(im,Tensor) and im.shape[0]<5: im=im.permute(1,2,0)\n", " ax.imshow(im, **kwargs)\n", " if title is not None: ax.set_title(title)\n", " ax.axis('off')\n", " return ax\n", "\n", "def show_title(o, ax=None):\n", " if ax is None: print(o)\n", " else: ax.set_title(o)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convention: a function whose name is a verb ending in *er* returns a function (either to get transforms directly or for use in the high-level API below)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def random_splitter(valid_pct=0.2, seed=None, **kwargs):\n", " \"Split `items` between train/val with `valid_pct` randomly.\"\n", " def _inner(o, **kwargs):\n", " if seed is not None: torch.manual_seed(seed)\n", " rand_idx = torch.randperm(len(o))\n", " cut = int(valid_pct * len(o))\n", " return rand_idx[cut:],rand_idx[:cut]\n", " return _inner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test\n", "trn,val = random_splitter(seed=42)([0,1,2,3,4,5])\n", "test_equal(trn, tensor([3, 2, 4, 1, 5]))\n", "test_equal(val, tensor([0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _grandparent_mask(items, name):\n", " # For string paths, the grandparent folder is the third component from the end\n", " return [(o.parent.parent.name if isinstance(o, Path) else o.split(os.path.sep)[-3]) == name for o in items]\n", "\n", "def grandparent_splitter(train_name='train', 
valid_name='valid', **kwargs):\n", " \"Split `items` from the grandparent folder names (`train_name` and `valid_name`).\"\n", " def _inner(o, **kwargs):\n", " return _grandparent_mask(o, train_name),_grandparent_mask(o, valid_name)\n", " return _inner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_TINY)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test\n", "#With `Path` filenames\n", "path = untar_data(URLs.MNIST_TINY)\n", "items = [path/'train'/'3'/'9932.png', path/'valid'/'7'/'7189.png', \n", " path/'valid'/'7'/'7320.png', path/'train'/'7'/'9833.png', \n", " path/'train'/'3'/'7666.png', path/'valid'/'3'/'925.png',\n", " path/'train'/'7'/'724.png', path/'valid'/'3'/'93055.png']\n", "trn,val = grandparent_splitter()(items)\n", "test_eq(trn,[True,False,False,True,True,False,True,False])\n", "test_eq(val,[False,True,True,False,False,True,False,True])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Label" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def parent_label(o, **kwargs):\n", " \"Label `item` with the parent folder name.\"\n", " # For string paths, the parent folder is the second component from the end\n", " return o.parent.name if isinstance(o, Path) else o.split(os.path.sep)[-2]\n", "\n", "def re_labeller(pat):\n", " \"Label `item` with regex `pat`.\"\n", " pat = re.compile(pat)\n", " def _inner(o, **kwargs):\n", " res = pat.search(str(o))\n", " assert res,f'Failed to find \"{pat}\" in \"{o}\"'\n", " return res.group(1)\n", " return _inner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pets DataSource" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's grab the Pets dataset first. Our `DataSource` will contain all the image files as items, and we'll randomly split them into two filts holding 80% and 20% of the data. To get our xs, we need to apply a `Transform` that opens the image from each filename. 
Here that transform is `PetTfm`, which also labels each image with a regex on its filename (a standalone `Imagify` transform, which only opens images, is defined later)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.PETS)/\"images\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PetTfm(Transform):\n", " def __init__(self, source):\n", " super().__init__()\n", " self.source,self.vocab = source,None\n", " self.labeller = re_labeller(pat = r'/([^/]+)_\\d+.jpg$')\n", " \n", " def setups(self, dsrc):\n", " vals = map(self.labeller, dsrc.train)\n", " self.vocab,self.o2i = uniqueify(vals, sort=True, bidir=True)\n", " \n", " def encodes(self, o):\n", " if self.vocab is None: return o\n", " return Image.open(o), self.o2i[self.labeller(o)]\n", "\n", " def decodes(self, o): return o[0],self.vocab[o[1]]\n", " def show(self, o, ax=None): show_image(o[0], ax, title=o[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = PetTfm(source)\n", "items = get_image_files(source)\n", "split_idx = random_splitter()(items)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets = DataSource(items, tfm, filts=split_idx)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To access an element, we need to specify its index and filter (the filter defaults to 0)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xy = pets.get(0,1); xy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can decode an element for display purposes."
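] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For intuition about what decoding does to the label, the next cell is a library-free sketch, not `PetTfm` itself: it reuses the same regex on made-up Pets-style filenames, builds a sorted vocabulary, encodes a filename into a `(filename, index)` pair and decodes the index back to the breed name." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "# Same pattern as PetTfm's labeller; the filenames below are made up\n", "toy_pat = re.compile(r'/([^/]+)_\\d+.jpg$')\n", "def toy_label(fn): return toy_pat.search(fn).group(1)\n", "\n", "toy_fns = ['/pets/images/great_pyrenees_173.jpg',\n", "           '/pets/images/Abyssinian_12.jpg',\n", "           '/pets/images/great_pyrenees_2.jpg']\n", "toy_vocab = sorted(set(map(toy_label, toy_fns)))\n", "toy_o2i = {v: i for i, v in enumerate(toy_vocab)}\n", "\n", "# Encode: (filename, class index); decode: (filename, class name)\n", "toy_enc = (toy_fns[0], toy_o2i[toy_label(toy_fns[0])])\n", "toy_dec = (toy_enc[0], toy_vocab[toy_enc[1]])\n", "toy_enc, toy_dec"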
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xyd = pets.decode(xy, 1); xyd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm.show(xyd)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _dsrc_show(self, o, filt=0, **kwargs): self.tfm.show(self.decode(o, filt), **kwargs)\n", "def _fl_show(self, o, **kwargs): self.dsrc.show(o, self.filt, **kwargs)\n", "def _fl_show_at(self, i, **kwargs): self.show(self[i], **kwargs)\n", "\n", "DataSource.show = _dsrc_show\n", "FilteredList.show = _fl_show\n", "FilteredList.show_at = _fl_show_at" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets.show(pets.get(0,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets.valid.show_at(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Image Transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we can batch our images, we'll need to apply some basic image transformations: converting them to RGB, resizing them to the same size and converting them to tensors. We also have to be ready for different kinds of targets: sometimes the transform shouldn't be applied to the target, and sometimes it should, possibly in a different way. We support images, segmentation masks, points and bounding boxes as targets." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "TfmY = Enum('TfmY', 'Mask Image Point Bbox No')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `ImageTransform` class may look ugly, but it just dispatches `apply` (when encoding) or `unapply` (when decoding) properly between `x` and `y`: each target goes through the `apply_*` or `unapply_*` method that matches `_tfm_y`, so every kind of target can have its own implementation. It also calls `randomize` once per item, so `x` and `y` share the same random state. This will be very important for data augmentation in the next notebook."
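] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dispatch can be sketched without images. In the next cell, the hypothetical `ToyShift` (a simplified stand-in written for illustration, not `ImageTransform`) picks `apply_<kind>` by name from its target kind and calls `randomize` once per call, so `x` and its targets get the same random shift." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "from enum import Enum\n", "\n", "# Hypothetical toy mirroring the idea of ImageTransform on plain numbers\n", "ToyKind = Enum('ToyKind', 'No Image Point')\n", "\n", "class ToyShift:\n", "    def __init__(self, kind=ToyKind.No): self.kind = kind\n", "    def randomize(self): self.shift = random.randint(1, 10)\n", "    def apply(self, x): return x + self.shift            # applied to x\n", "    def apply_no(self, y): return y                      # target left untouched\n", "    def apply_image(self, y): return self.apply(y)       # target treated like x\n", "    def apply_point(self, y): return y + 2 * self.shift  # target-specific version\n", "    def __call__(self, xy):\n", "        x, *ys = xy\n", "        self.randomize()  # one draw shared by x and every target\n", "        f = getattr(self, f'apply_{self.kind.name.lower()}')\n", "        return (self.apply(x), *(f(y) for y in ys))\n", "\n", "ToyShift(ToyKind.Image)((0, 100)), ToyShift(ToyKind.Point)((0, 0))"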
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ImageTransform(Transform):\n", " \"Basic class for image transforms.\"\n", " _order,_tfm_y = 10,TfmY.No\n", " \n", " def randomize(self): pass\n", " \n", " def encodes(self, o, **kwargs):\n", " self.x,*y = o\n", " self.randomize() # Ensure we have the same state for x and y\n", " return ( self.apply(self.x), *( self.apply_y(y_, **kwargs) for y_ in y))\n", " \n", " def decodes(self, o, **kwargs):\n", " self.x,*y = o\n", " return (self.unapply(self.x), *(self.unapply_y(y_, **kwargs) for y_ in y))\n", "\n", " def _tfm_name(self, is_decode=False):\n", " return f\"{'un' if is_decode else ''}apply_{self._tfm_y.name.lower()}\"\n", "\n", " def apply_no (self, y): return y\n", " def apply_image(self, y): return self.apply(y)\n", " def apply_mask (self, y): return self.apply_image(y)\n", " def apply_point(self, y): return y\n", " def apply_bbox (self, y): return self.apply_point(y)\n", "\n", " def unapply_no (self, y): return y\n", " def unapply_image(self, y): return self.unapply(y)\n", " def unapply_mask (self, y): return self.unapply_image(y)\n", " def unapply_point(self, y): return y\n", " def unapply_bbox (self, y): return self.unapply_point(y)\n", " \n", " def apply (self, x): return x\n", " def unapply (self, x): return x\n", " def apply_y(self, y): return getattr(self, self._tfm_name(False))(y)\n", " def unapply_y(self, y): return getattr(self, self._tfm_name(True ))(y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "import random\n", "class FakeTransform(ImageTransform):\n", " def randomize(self): self.a = random.randint(1,1000)\n", " def apply(self, x): return x + self.a\n", " def apply_mask(self, x): return x + 5\n", " def apply_point(self, x): return x + 2\n", "\n", "tfm = FakeTransform()\n", "xy = x,y = 5,10\n", "#Basic behavior: x has changed, not y\n", "t1 = tfm(xy)\n", "assert 
t1[0]!=x and t1[1]==y, t1\n", "#Check the same random integer was used for x and y when transforming y\n", "tfm._tfm_y=TfmY.Image; t1 = tfm(xy)\n", "test_eq(t1[0] - 5,t1[1] - 10)\n", "#Check mask, point and bbox implementations\n", "tfm._tfm_y=TfmY.Mask ; test_eq(tfm(xy)[1],15)\n", "tfm._tfm_y=TfmY.Point; test_eq(tfm(xy)[1],12)\n", "tfm._tfm_y=TfmY.Bbox ; test_eq(tfm(xy)[1],12)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our first transform converts an image to 'RGB' mode. We can specify different modes for the xs and ys: by default x uses 'RGB', while y uses `mode_x` if our ys are images and 'L' if our ys are segmentation masks (passing `mode_y` overrides both)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class DecodeImg(ImageTransform):\n", " \"Convert regular image to RGB, masks to L mode.\"\n", " def __init__(self, mode_x='RGB', mode_y=None): self.mode_x,self.mode_y = mode_x,mode_y\n", " def apply(self, x): return x.convert(self.mode_x)\n", " def apply_image(self, y): return y.convert(ifnone(self.mode_y,self.mode_x))\n", " def apply_mask(self, y): return y.convert(ifnone(self.mode_y,'L'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our second transform resizes an image using a given resampling mode. It defaults to bilinear for images and nearest-neighbor for segmentation masks."
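] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see why masks are resized with nearest-neighbor rather than bilinear resampling, the next cell uses two hypothetical pure-Python 1D resizers (a simplification, not PIL's implementation): linear interpolation blends neighboring class ids into values that are not classes, while nearest-neighbor only copies existing ids." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical 1D resizers, for illustration only (PIL works on 2D images)\n", "def toy_resize_nearest(row, n):\n", "    # Each output position copies the input value whose cell it falls into\n", "    return [row[min(int(i * len(row) / n), len(row) - 1)] for i in range(n)]\n", "\n", "def toy_resize_linear(row, n):\n", "    out = []\n", "    for i in range(n):\n", "        pos = i * (len(row) - 1) / (n - 1)  # position in input coordinates\n", "        lo = int(pos)\n", "        hi = min(lo + 1, len(row) - 1)\n", "        out.append(row[lo] + (row[hi] - row[lo]) * (pos - lo))  # blend neighbors\n", "    return out\n", "\n", "toy_mask = [0, 0, 3, 3]  # one row of a segmentation mask with class ids 0 and 3\n", "toy_resize_nearest(toy_mask, 7), toy_resize_linear(toy_mask, 7)"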
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ResizeFixed(ImageTransform):\n", " \"Resize image to `size` using `mode_x` (and `mode_y` on targets).\"\n", " _order=15\n", " def __init__(self, size, mode_x=Image.BILINEAR, mode_y=None):\n", " if isinstance(size,int): size=(size,size)\n", " self.size = (size[1],size[0]) #PIL takes size in the other way round\n", " self.mode_x,self.mode_y = mode_x,mode_y\n", " \n", " def apply(self, x): return x.resize(self.size, self.mode_x)\n", " def apply_image(self, y): return y.resize(self.size, ifnone(self.mode_y,self.mode_x))\n", " def apply_mask(self, y): return y.resize(self.size, ifnone(self.mode_y,Image.NEAREST))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The transformation to tensors is done in two steps just in case one wants to apply transforms to byte tensors. The permutation of axes needs to be reversed for display, so we have an `unapply` function (which is what is called by `decode` in an `ImageTransform`)." 
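] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The layout change can be sketched with nested lists instead of tensors. Assuming the buffer is row-major with channels interleaved (which is how PIL's `tobytes()` lays out an 'RGB' image), the next cell rebuilds an HWC array from a flat buffer, permutes it to CHW as `ToByteTensor.apply` does, and permutes it back as `unapply` does for display. The helper names are hypothetical." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical helpers mirroring view(h,w,-1), permute(2,0,1) and permute(1,2,0)\n", "def toy_to_hwc(buf, h, w, c):\n", "    # Flat, row-major, channel-interleaved buffer -> [row][col][channel]\n", "    return [[list(buf[(r*w + col)*c:(r*w + col + 1)*c]) for col in range(w)]\n", "            for r in range(h)]\n", "\n", "def toy_hwc_to_chw(a):\n", "    h, w, c = len(a), len(a[0]), len(a[0][0])\n", "    return [[[a[r][col][ch] for col in range(w)] for r in range(h)] for ch in range(c)]\n", "\n", "def toy_chw_to_hwc(a):\n", "    c, h, w = len(a), len(a[0]), len(a[0][0])\n", "    return [[[a[ch][r][col] for ch in range(c)] for col in range(w)] for r in range(h)]\n", "\n", "toy_buf = bytes(range(12))             # a 2x2 RGB image, 3 bytes per pixel\n", "toy_hwc = toy_to_hwc(toy_buf, 2, 2, 3)\n", "toy_chw = toy_hwc_to_chw(toy_hwc)      # channels first, as fed to a model\n", "toy_chw[0], toy_chw_to_hwc(toy_chw) == toy_hwc"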
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ToByteTensor(ImageTransform):\n", " \"Transform our items to byte tensors.\"\n", " _order=20\n", " def apply(self, x):\n", " res = torch.ByteTensor(torch.ByteStorage.from_buffer(x.tobytes()))\n", " w,h = x.size\n", " return res.view(h,w,-1).permute(2,0,1)\n", " \n", " def unapply(self, x):\n", " return x[0] if x.shape[0] == 1 else x.permute(1,2,0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, we convert our tensors to floats (or longs for segmentation masks) and divide them by 255 (a different value can be passed as `div_x`; masks are only divided when `div_y` is set)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ToFloatTensor(ImageTransform):\n", " \"Transform our items to float tensors (int in the case of mask).\"\n", " _order=5 #Need to run after CUDA on the GPU\n", " def __init__(self, div_x=255., div_y=None): self.div_x,self.div_y = div_x,div_y\n", " def apply(self, x): return x.float().div_(self.div_x)\n", " def apply_mask(self, x): \n", " return x.long() if self.div_y is None else x.long().div_(self.div_y)\n", " \n", " def unapply(self, x): return torch.clamp(x, 0., 1.)\n", " def unapply_mask(self, x): return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check that it all works properly."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [PetTfm(source), DecodeImg(), ResizeFixed(128), ToByteTensor()]\n", "pets = DataSource(items, tfms, filts=split_idx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xy = pets.get(0,1); xy[0].type()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets.show(xy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DataBunch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With transforms that turn our images into tensors of the same size, we're ready to create batches and dataloaders. We wrap a PyTorch `DataLoader` to add batch transforms; additional kwargs (`tfm_kwargs`) are passed along to those transforms. Like before, these transforms can be decoded for display purposes (to undo a normalization, for instance)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def apply_all(o, fs, fname=None, **kwargs):\n", " for f in fs:\n", " if fname is not None: f = getattr(f,fname,noop)\n", " o = f(o, **kwargs)\n", " return o" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class TfmDataLoader():\n", " def __init__(self, dl, tfms=None, **tfm_kwargs):\n", " self.dl,self.tfms,self.tfm_kwargs = dl,order_sorted(tfms),tfm_kwargs\n", " \n", " def __len__(self): return len(self.dl)\n", " def __iter__(self):\n", " for b in self.dl: yield apply_all(b, self.tfms, **self.tfm_kwargs)\n", " \n", " def one_batch(self): return next(iter(self))\n", " def decode_batch(self): return self.decode(self.one_batch())\n", " def decode(self, o): return apply_all(o, reversed(self.tfms), fname='decode')\n", " \n", " def __getattr__(self, k):\n", " try: return getattr(self.dataset, k)\n", " except AttributeError: raise AttributeError(k) from None\n", " \n", " @property\n", " def dataset(self): return self.dl.dataset" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "Then we add basic functions to create dataloaders from a `DataSource`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "from torch.utils.data.dataloader import DataLoader\n", "\n", "def get_dl(dset, bs=64, tfms=None, tfm_kwargs=None, **kwargs):\n", " dl = DataLoader(dset, bs, **kwargs)\n", " return TfmDataLoader(dl, tfms=tfms, **(ifnone(tfm_kwargs,{})))\n", "\n", "def get_dls(dsrc, bs=64, tfms=None, tfm_kwargs=None, **kwargs):\n", " return [get_dl(dsrc[i], bs, shuffle=i==0, tfms=tfms, tfm_kwargs=tfm_kwargs, **kwargs)\n", " for i in range_of(dsrc)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = get_dls(pets, tfms=ToFloatTensor())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a convenience function to grab the k-th item in a batch, even if the batch is made up of lists (or tuples) of tensors." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def grab_item(b,k):\n", " if isinstance(b, (list,tuple)): return [grab_item(o,k) for o in b]\n", " return b[k]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def show_batch(b, show, items=9, cols=3, figsize=None, show_func=None, **kwargs):\n", " rows = (items+cols-1) // cols\n", " if figsize is None: figsize = (cols*3, rows*3)\n", " fig,axs = plt.subplots(rows, cols, figsize=figsize)\n", " for k,ax in enumerate(axs.flatten()):\n", " show(grab_item(b,k), ax=ax, show_func=show_func, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class DataBunch():\n", " \"Basic wrapper around several `DataLoader`s.\"\n", " def __init__(self, *dls): self.dls = dls\n", " def __getitem__(self, i): return self.dls[i]\n", " \n", " @property\n", " def train_dl(self): return self.dls[0]\n", "
@property\n", " def valid_dl(self): return self.dls[1]\n", " @property\n", " def train_ds(self): return self.train_dl.dataset\n", " @property\n", " def valid_ds(self): return self.valid_dl.dataset\n", "\n", " def show_batch(self, i=0, items=9, cols=3, figsize=None, show_func=None, **kwargs):\n", " b = self.dls[i].decode_batch()\n", " show = self[i].dataset.show\n", " show_batch(b, show, items, cols, figsize=figsize, show_func=show_func, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch(*dls)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = data[0].one_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x.shape,x.type(),y.shape,y.type()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's monkey-patch a `databunch` method onto `DataSource` to quickly create a `DataBunch`."
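] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Monkey-patching works because a plain function whose first parameter is `self` becomes a method once it is assigned as a class attribute. Here is a minimal sketch of the pattern with a hypothetical `ToySource` class, unrelated to `DataSource`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical class, used only to show the monkey-patching pattern\n", "class ToySource:\n", "    def __init__(self, items): self.items = items\n", "\n", "def _toy_batches(self, bs=2):\n", "    # Defined outside the class; `self` is bound when called on an instance\n", "    return [self.items[i:i + bs] for i in range(0, len(self.items), bs)]\n", "\n", "ToySource.batches = _toy_batches\n", "ToySource([1, 2, 3, 4, 5]).batches()"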
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _dsrc_databunch(self, bs=64, tfms=None, **kwargs):\n", " res = DataBunch(*get_dls(self, bs=bs, tfms=tfms, **kwargs))\n", " res.dsrc = self\n", " return res\n", "\n", "DataSource.databunch = _dsrc_databunch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cuda" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda',0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "from fastai.torch_core import to_device, to_cpu\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Cuda(Transform):\n", " _order = 0\n", " def __init__(self,device): self.device=device\n", " def encodes(self, b, tfm_y=TfmY.No): return to_device(b, self.device)\n", " def decodes(self, b): return to_cpu(b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Normalization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see other batch transforms in the next chapter, but one that is very common is normalization."
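] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Per channel, normalizing computes `(x - mean) / std` when encoding and `x * std + mean` when decoding, so decoding undoes the encoding. The next cell checks this round trip on a single RGB pixel with plain Python lists (a simplification: `Normalize` works on whole batches with tensor broadcasting); the mean and std of 0.5 are the values used with `Normalize` in this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# One RGB pixel with values in [0, 1], normalized channel by channel\n", "toy_mean, toy_std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]\n", "def toy_normalize(px): return [(v - m) / s for v, m, s in zip(px, toy_mean, toy_std)]\n", "def toy_denorm(px): return [v * s + m for v, m, s in zip(px, toy_mean, toy_std)]\n", "\n", "toy_px = [0.0, 0.5, 1.0]\n", "toy_norm = toy_normalize(toy_px)  # values in [0, 1] are mapped to [-1, 1]\n", "toy_back = toy_denorm(toy_norm)   # back to the original values\n", "toy_norm, toy_back"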
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Normalize(Transform):\n", " _order=99\n", " def __init__(self, mean, std, do_x=True, do_y=False):\n", " self.mean,self.std,self.do_x,self.do_y = mean,std,do_x,do_y\n", " \n", " def encodes(self, b):\n", " x,y = b\n", " if self.do_x: x = self.normalize(x)\n", " if self.do_y: y = self.normalize(y)\n", " return x,y\n", " \n", " def decodes(self, b):\n", " x,y = b\n", " if self.do_x: x = self.denorm(x)\n", " if self.do_y: y = self.denorm(y)\n", " return x,y\n", " \n", " def normalize(self, x): return (x - self.mean) / self.std\n", " def denorm(self, x): return x * self.std + self.mean" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mean,std = tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda(),tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda()\n", "data = pets.databunch(tfms = [Cuda(device), ToFloatTensor(), Normalize(mean,std)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = data[0].one_batch()\n", "xd,yd = data[0].decode_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x.type(), xd.type(), x.mean(), x.std(), xd.mean(), xd.std()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pets 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds_tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor()]\n", "dl_tfms = [Cuda(device), ToFloatTensor()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Imagify(Transform):\n", " def __init__(self, f=Image.open, cmap=None, alpha=1.): self.f,self.cmap,self.alpha = f,cmap,alpha\n", " def encodes(self, fn): return self.f(fn)\n", "
def show(self, im, ax=None, figsize=None, cmap=None, alpha=None):\n", " return show_image(im, ax, figsize=figsize,\n", " cmap=ifnone(cmap,self.cmap),\n", " alpha=ifnone(alpha,self.alpha))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get our ys, we need to apply the regex labeller (the `re_labeller` function from before), followed by a transform that builds the list of classes and then maps each category to its index. We call that transform `Categorize`. It needs a setup step (`setups`) to create the classes, and since it is reversible we implement `decodes`. It is also a base transform, so we implement its `show` method.\n", "\n", "Since it needs to run after the labelling transform (here our regex labeller), we give it an `_order` of 1." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Categorize(Transform):\n", " _order=1\n", " def __init__(self): self.o2i = None\n", " def encodes(self,o): return self.o2i[o] if self.o2i else o\n", " def decodes(self, o): return self.vocab[o]\n", " def show(self, o, ax=None): show_title(o, ax)\n", " def setups(self, dsrc): self.vocab,self.o2i = uniqueify(dsrc.train, sort=True, bidir=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "labeller = re_labeller(pat = r'/([^/]+)_\\d+.jpg$')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A transform that defines a base object can have a `show` method: for instance, the transform that opens an `Image` has one. When displaying objects, the API decodes them and uses the first transform that provides a `show` method (this can be overridden by passing a custom `show`, as we'll see later).\n", "\n", "The `show_xs` function combines the `show` methods of the base transforms to display x and y together. Each element of `shows` can be either an object with a `show` method (such as a transform) or a plain `show` function."
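, "\n", "\n", "The dispatch in `show_xs` relies on `getattr(show, 'show', show)`: if the object has a `show` attribute we call that, otherwise we assume the object is itself a show function. A minimal sketch with hypothetical stand-ins (not library classes) that return strings instead of plotting:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical stand-ins: an object with a `show` method...\n", "class HasShow:\n", "    def show(self, o, ax=None): return f'obj:{o}'\n", "\n", "# ...and a plain function\n", "def plain_show(o, ax=None): return f'func:{o}'\n", "\n", "def resolve_show(show):\n", "    # Same lookup as `show_xs`: use `show.show` if it exists, else `show` itself\n", "    return getattr(show, 'show', show)\n", "\n", "outs = [resolve_show(s)(o) for o,s in zip(['img', 'label'], [HasShow(), plain_show])]\n", "assert outs == ['obj:img', 'func:label']\n", "outs"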
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def show_xs(xs, shows, ax=None, **kwargs):\n", " for x,show in zip(xs,shows):\n", " # can pass func or obj with a `show` method\n", " show = getattr(show, 'show', show)\n", " ax = show(x, ax=ax, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class DsrcTfm():\n", " def __init__(self, ttfms, tfm=noop):\n", " self.ttfms = [Transforms(tfm) for tfm in listify(ttfms)]\n", " self.tfm,self.activ,self.done_setup = Transforms(tfm),None,False\n", "\n", " def __call__(self, o, **kwargs):\n", " if self.activ: return self.activ(o, **kwargs)\n", " o = [t(o, **kwargs) for t in self.ttfms]\n", " return self.tfm(o, **kwargs)\n", " \n", " def decode(self, o, **kwargs):\n", " o = self.tfm.decode(o, **kwargs)\n", " return [t.decode(p, **kwargs) for p,t in zip(o,self.ttfms)]\n", "\n", " def setup(self, dsrc):\n", " if self.done_setup: return\n", " for tfm in self.ttfms:\n", " self.activ = tfm\n", " tfm.setup(dsrc)\n", " self.activ=None\n", " self.tfm.setup(dsrc)\n", " self.done_setup = True\n", " \n", " def show(self, o, **kwargs): return show_xs(o, self.ttfms, **kwargs)\n", " def __repr__(self): return f'DsrcTfm({self.ttfms}\\n{self.tfm})\\n'\n", " \n", " @property\n", " def xt(self): return self.ttfms[0]\n", " @property\n", " def yt(self): return self.ttfms[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfm = DsrcTfm([Imagify(), [labeller,Categorize()]], ds_tfms)\n", "pets = DataSource(items, tfm=tfm, filts=split_idx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def tfm_dsrc(items, filts, xt, yt, labeller, ds_tfms=None):\n", " tfm = DsrcTfm([xt, [labeller,yt]], ds_tfms)\n", " return DataSource(items, tfm=tfm, filts=filts)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "pets = tfm_dsrc(items, split_idx, Imagify(), Categorize(), labeller, ds_tfms=ds_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xy = pets.decoded(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_xs(xy, tfm.ttfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets.train.show_at(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = pets.databunch(bs=16, tfms=dl_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Higher level API" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class DataBlock():\n", " @staticmethod\n", " def get_items(source): raise NotImplementedError\n", " @staticmethod\n", " def split(items): raise NotImplementedError\n", " @staticmethod\n", " def label_func(item): raise NotImplementedError\n", " \n", " def __init__(self, source):\n", " self.source = source\n", " xt,yt = self.types()\n", " self.tfm = DsrcTfm([xt, [self.__class__.label_func,yt]])\n", "\n", " def datasource(self, tfms=None):\n", " items = self.__class__.get_items(self.source, self=self)\n", " split_idx = self.__class__.split(items, self=self)\n", " return DataSource(items, [self.tfm]+listify(tfms), filts=split_idx)\n", " \n", " def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, **kwargs):\n", " return self.datasource(tfms=ds_tfms).databunch(bs, tfms=dl_tfms, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PetsData(DataBlock):\n", " def types(self): return Imagify(),Categorize()\n", " get_items = image_getter()\n", " split = random_splitter()\n", " label_func = re_labeller(pat = 
r'/([^/]+)_\\d+.jpg$')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.PETS)/\"images\"\n", "dblk = PetsData(source)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.dsrc.train.show_at(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "' '.join(dblk.tfm.yt.vocab)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Try different data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### MNIST" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MnistData(DataBlock):\n", " def types(self): return Imagify(),Categorize()\n", " get_items = get_image_files\n", " split = grandparent_splitter(train_name='training', valid_name='testing')\n", " label_func = parent_label" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.MNIST)\n", "data = MnistData(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are several ways to get `show_batch` to display our images properly. First, we can pass a custom function that replaces the one used to show the `x`:"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch(shows = (partial(show_image, cmap='gray'), None))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or just set the default `cmap` to gray in `types`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MnistDataBW(MnistData):\n", " def types(self): return Imagify(cmap='gray'),Categorize()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = MnistDataBW(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Planet" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.PLANET_SAMPLE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def onehot(x, c, a=1.):\n", " \"Return the `a`-hot encoded tensor for `x` with `c` classes.\"\n", " res = torch.zeros(c)\n", " if a<1: res += (1-a)/(c-1)\n", " res[x] = a\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "test_equal(onehot(1,5), tensor([0.,1.,0.,0.,0.]))\n", "test_equal(onehot([1,3],5), tensor([0.,1.,0.,1.,0.]))\n", "test_equal(onehot(tensor([1,3]),5), tensor([0.,1.,0.,1.,0.]))\n", "test_equal(onehot([True,False,True,True,False],5), tensor([1.,0.,1.,1.,0.]))\n", "test_equal(onehot([],5), tensor([0.,0.,0.,0.,0.]))\n", "\n", "test_equal(onehot(1,5,0.9), tensor([0.025,0.9,0.025,0.025,0.025]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class MultiCategorize(Transform):\n", " _order=1\n", " def __init__(self): self.vocab = None\n", " def 
__call__(self,x): return [self.o2i[o] for o in x if o in self.o2i]\n", " def decode(self, o): return [self.vocab[i] for i in o]\n", " @property\n", " def c(self): return len(self.vocab)\n", " def show(self, o, ax=None): \n", " (print if ax is None else ax.set_title)(';'.join(o))\n", " \n", " def setup(self, dsrc):\n", " if self.vocab is not None: return\n", " vals = set()\n", " for c in dsrc.train: vals = vals.union(set(c))\n", " self.vocab,self.o2i = uniqueify(list(vals), sort=True, bidir=True)\n", " \n", "class OneHotEncode(Transform):\n", " _order=10\n", " def setup(self, items): self.c = items.activ_tfm.c\n", " def __call__(self, o): return onehot(o, self.c) if self.c is not None else o\n", " def decode(self, o): return [i for i,x in enumerate(o) if x == 1]\n", " \n", "def multi_category(): return [MultiCategorize(), OneHotEncode()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "tfm = MultiCategorize()\n", "#Even if 'c' is the first class, vocab is sorted for reproducibility\n", "ds = DataSource([['c','a'], ['a','b'], ['b'], []], [tfm], filts=[[0,1,2,3], []])\n", "test_eq(tfm.vocab,['a','b','c'])\n", "test_eq(tfm(['b','a']),[1,0])\n", "test_eq(tfm.decode([2,0]),['c','a'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def get_str_column(df, col_name, prefix='', suffix='', delim=None):\n", " \"Read `col_name` in `df`, optionally adding `prefix` or `suffix` and splitting on `delim`.\"\n", " values = df[col_name].values.astype(str)\n", " values = np.char.add(np.char.add(prefix, values), suffix)\n", " if delim is not None:\n", " values = np.array(list(csv.reader(values, delimiter=delim)))\n", " return values" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "df = pd.DataFrame({'a': ['cat', 'dog', 'car'], 'b': ['a b', 'c d', 'a e']})\n", "test_np_eq(get_str_column(df, 'a'), np.array(['cat', 'dog',
'car']))\n", "test_np_eq(get_str_column(df, 'a', prefix='o'), np.array(['ocat', 'odog', 'ocar']))\n", "test_np_eq(get_str_column(df, 'a', suffix='.png'), np.array(['cat.png', 'dog.png', 'car.png']))\n", "test_np_eq(get_str_column(df, 'b', delim=' '), np.array([['a','b'], ['c','d'], ['a','e']]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PlanetData(DataBlock):\n", " def types(self): return Imagify(),multi_category()\n", " \n", " def get_items(source, self): \n", " df = pd.read_csv(self.source/'labels.csv')\n", " items = get_str_column(df, 'image_name', prefix=f'{self.source}/train/', suffix='.jpg')\n", " labels = get_str_column(df, 'tags', delim=' ')\n", " self.item2label = {i:s for i,s in zip(items,labels)}\n", " return items\n", " \n", " split = random_splitter()\n", " def label_func(item, self): return self.item2label[item]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.PLANET_SAMPLE)\n", "dsrc = PlanetData(source)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = dsrc.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Camvid" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class SegmentMask(Item):\n", " tfm = partial(Imagify, cmap='tab20', alpha=0.5)\n", " tfm_kwargs = {'tfm_y': TfmY.Mask}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CamvidData(DataBlock):\n", " types = Image,SegmentMask\n", " get_items = image_getter('images')\n", " split = random_splitter()\n", " label_func = lambda o,self: self.source/'labels'/f'{o.stem}_P{o.suffix}'" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.CAMVID_TINY)\n", "data = CamvidData(source).databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch(cmap='tab20')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Biwi" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class PointScaler(Transform):\n", " _order = 5 #Run before we apply any ImageTransform\n", " def __init__(self, do_scale=True, y_first=False): \n", " self.do_scale,self.y_first = do_scale,y_first\n", " \n", " def __call__(self, o, tfm_y=TfmY.No):\n", " x,y = o\n", " if not isinstance(y, torch.Tensor): y = tensor(y)\n", " y = y.view(-1, 2).float()\n", " if not self.y_first: y = y.flip(1)\n", " if self.do_scale: y = y * 2/tensor(list(x.size)).float() - 1\n", " return (x,y)\n", " \n", " def decode(self, o, tfm_y=TfmY.No):\n", " x,y = o\n", " y = y.flip(1)\n", " y = (y+1) * tensor([x.shape[:2]]).float()/2\n", " return (x,y)\n", "\n", "class PointShow(Transform):\n", " def show(self, x, ax=None): ax.scatter(x[:, 1], x[:, 0], s=10, marker='.', c='r')\n", "\n", "class Points(Item):\n", " tfm,tfm_ds,tfm_kwargs = PointShow,PointScaler,{'tfm_y': TfmY.Point}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class BiwiData(DataBlock):\n", " types = Image,Points\n", " def __init__(self, source, *args, **kwargs):\n", " super().__init__(source, *args, **kwargs)\n", " self.fn2ctr = pickle.load(open(source/'centers.pkl', 'rb'))\n", " \n", " get_items = image_getter('images')\n", " split = random_splitter()\n", " label_func = lambda o,self: self.fn2ctr[o.name]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {},
"outputs": [], "source": [ "dblk = BiwiData(untar_data(URLs.BIWI_SAMPLE))\n", "data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Coco" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "from fastai.vision.data import get_annotations\n", "from matplotlib import patches, patheffects\n", "\n", "def _draw_outline(o, lw):\n", " o.set_path_effects([patheffects.Stroke(linewidth=lw, foreground='black'), patheffects.Normal()])\n", "\n", "def _draw_rect(ax, b, color='white', text=None, text_size=14, hw=True, rev=False):\n", " lx,ly,w,h = b\n", " if rev: lx,ly,w,h = ly,lx,h,w\n", " if not hw: w,h = w-lx,h-ly\n", " patch = ax.add_patch(patches.Rectangle((lx,ly), w, h, fill=False, edgecolor=color, lw=2))\n", " _draw_outline(patch, 4)\n", " if text is not None:\n", " patch = ax.text(lx,ly, text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')\n", " _draw_outline(patch,1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class BBoxScaler(PointScaler):\n", " def __call__(self, o, tfm_y=TfmY.Bbox): \n", " x,y = o\n", " return x, (super().__call__((x,y[0]))[1].view(-1,4),y[1])\n", " \n", " def decode(self, o, tfm_y=TfmY.Bbox): \n", " x,y = o\n", " _,bbox = super().decode((x,y[0].view(-1,2)))\n", " return x, (bbox.view(-1,4),y[1])\n", " \n", "class BBoxencodes(Transform):\n", " _order=1\n", " def __init__(self): self.vocab = None\n", " \n", " def __call__(self,o):\n", " x,y = o\n", " return (x,[self.otoi[o_] for o_ in y if o_ in self.otoi])\n", " \n", " def decode(self, o):\n", " x,y = o\n", " return x, [self.vocab[i] for i in y]\n", " \n", " def setup(self, dsrc):\n", " if self.vocab is not None: return\n", " vals = set()\n", " for bb,c in 
dsrc.train: vals = vals.union(set(c))\n", " self.vocab,self.otoi = uniqueify(list(vals), sort=True, bidir=True, start='#bg')\n", " \n", " def show(self, x, ax):\n", " bbox,label = x\n", " for b,l in zip(bbox, label): \n", " if l != '#bg': _draw_rect(ax, b, hw=False, rev=True, text=l)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class BBox(Item): tfm,tfm_ds,tfm_kwargs = BBoxencodes,BBoxScaler,{'tfm_y': TfmY.Bbox}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def bb_pad_collate(samples, pad_idx=0):\n", " max_len = max([len(s[1][1]) for s in samples])\n", " bboxes = torch.zeros(len(samples), max_len, 4)\n", " labels = torch.zeros(len(samples), max_len).long() + pad_idx\n", " imgs = []\n", " for i,s in enumerate(samples):\n", " imgs.append(s[0][None])\n", " bbs, lbls = s[1]\n", " if not (bbs.nelement() == 0):\n", " bboxes[i,-len(lbls):] = bbs\n", " labels[i,-len(lbls):] = tensor(lbls)\n", " return torch.cat(imgs,0), (bboxes,labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CocoData(DataBlock):\n", " types = Image,BBox\n", " def __init__(self, source, *args, **kwargs):\n", " super().__init__(source, *args, **kwargs)\n", " images, lbl_bbox = get_annotations(source/'train.json')\n", " self.img2bbox = dict(zip(images, lbl_bbox))\n", " \n", " get_items = image_getter('train')\n", " split = random_splitter()\n", " label_func = lambda o,self: self.img2bbox[o.name]\n", " \n", " def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):\n", " return super().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=bs, tfm_kwargs=tfm_kwargs,\n", " collate_fn=bb_pad_collate, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.COCO_TINY)\n", "dblk = CocoData(source)\n", "data = 
dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use of the low-level API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also use `DataSource` directly, with a single transform that does everything, instead of going through the data blocks. If you want to use `show_batch`, however, you will have to provide your own `show` method (there is no need to decode when everything happens in one transform)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def size_f(x): return tensor(x.size).float()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.COCO_TINY)\n", "fns, lbl_bbox = get_annotations(path/'train.json')\n", "img2bbox = dict(zip(fns, lbl_bbox))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CocoTransform(Transform):\n", " def __init__(self): self.vocab = None\n", " def setup(self, data):\n", " if self.vocab is not None: return\n", " vals = set()\n", " for c in data.train: vals = vals.union(set(img2bbox[c.name][1]))\n", " self.vocab,self.otoi = uniqueify(list(vals), sort=True, bidir=True, start='#bg')\n", " \n", " def __call__(self, fn):\n", " img = Image.open(fn)\n", " bbox,lbl = img2bbox[fn.name]\n", " #flip and rescale to -1,1\n", " bbox = tensor(bbox).view(-1,2).flip(1) * 2/size_f(img) - 1\n", " lbl = [self.otoi[l] for l in lbl if l in self.otoi]\n", " return (img, [bbox.view(-1,4), lbl])\n", " \n", " def show(self, o, ax):\n", " img,(bbox,lbl) = o\n", " show_image(img, ax)\n", " lbl = [self.vocab[l] for l in lbl if l != 0] #Unpad and decode\n", " bbox = bbox[-len(lbl):,] #Unpad\n", " bbox = (bbox.view(-1,2) + 1) * tensor(img.shape[:2]).float() / 2\n", " bbox = 
bbox.flip(1).view(-1,4)\n", " for b,l in zip(bbox, lbl):
_draw_rect(ax, b, hw=False, rev=True, text=l)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fnames = get_image_files(path)\n", "splits = random_splitter()(fnames)\n", "ds = DataSource(fnames, tfms=CocoTransform(), filts=splits, tfm_y=TfmY.Bbox)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds = ds.transformed(tfms=ds_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch(*get_dls(ds, 16, collate_fn=bb_pad_collate, tfms=dl_tfms))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! python notebook2script.py \"200_datablock_config.ipynb\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }