{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "from fastai.datasets import URLs, untar_data\n", "from pathlib import Path\n", "import pandas as pd, numpy as np, torch, re, PIL, os, mimetypes, csv, itertools\n", "import matplotlib.pyplot as plt\n", "from collections import OrderedDict\n", "from typing import *\n", "from enum import Enum\n", "from functools import partial,reduce\n", "from torch import as_tensor,Tensor\n", "from IPython.core.debugger import set_trace" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Core functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def ifnone(a, b): return b if a is None else a\n", "def noop ( x, *args, **kwargs): return x\n", "def noops(self, x, *args, **kwargs): return x\n", "def range_of(x): return list(range(len(x)))\n", "torch.Tensor.ndim = property(lambda x: x.dim())\n", "\n", "import operator\n", "\n", "def test(a,b,cmp,cname=None):\n", " if cname is None: cname=cmp.__name__\n", " assert cmp(a,b),f\"{cname}:\\n{a}\\n{b}\"\n", "\n", "def test_eq(a,b): test(a,b,operator.eq,'==')\n", "def test_ne(a,b): test(a,b,operator.ne,'!=')\n", "def test_equal(a,b): test(a,b,torch.equal,'==')\n", "def test_np_eq(a,b): test(a,b,np.array_equal,'==')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "test_eq(noop(1),1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def listify(o):\n", " \"Make `o` a list.\"\n", " if o is None: return []\n", " if isinstance(o, list): return o\n", " if isinstance(o, str): return [o]\n", " if not isinstance(o, Iterable): return [o]\n", " #Rank 0 tensors in PyTorch are Iterable but don't have a length.\n", " try: a = len(o)\n", " except: return [o]\n", " return list(o)\n", "\n", "def tuplify(o):\n", " \"Make `o` a tuple.\"\n", " return tuple(listify(o))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def compose(*funcs): return reduce(lambda f,g: lambda x: f(g(x)), reversed(funcs), noop)\n", "def is_listy(x:Any)->bool: return isinstance(x, (tuple,list))\n", "\n", "def tensor(x, *rest):\n", " \"Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly.\"\n", " if len(rest): x = tuplify(x)+rest\n", " # Pytorch bug in dataloader using num_workers>0\n", " if is_listy(x) and len(x)==0: return tensor(0)\n", " res = torch.tensor(x) if is_listy(x) else as_tensor(x)\n", " if res.dtype is torch.int32:\n", " warn('Tensor is int32: upgrading to int64; for better performance use int64 input')\n", " return res.long()\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "test_eq(listify(None),[])\n", "test_eq(listify([1,2,3]),[1,2,3])\n", "test_ne(listify([1,2,3]),[1,2,])\n", "test_eq(listify('abc'),['abc'])\n", "test_eq(listify(range(0,3)),[0,1,2])\n", "test_eq(listify(tensor(0)),[tensor(0)])\n", "test_eq(listify([tensor(0),tensor(1)]),[tensor(0),tensor(1)])\n", "test_eq(listify(tensor([0.,1.1])),[0,1.1])\n", "\n", "test_eq(tuplify(None),())\n", "test_eq(tuplify([1,2,3]),(1,2,3))\n", "test_eq(tuplify(tensor([0.,1.1])),(0,1.1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "from inspect import getfullargspec\n", "\n", "def has_param(func, p):\n", " \"Check if `func` accepts `p` as argument.\"\n", " return p in getfullargspec(func).args\n", "\n", "def feed_kwargs(func, *args, **kwargs):\n", " \"Feed `args` and the `kwargs` `func` accepts to `func`.\"\n", " signature = getfullargspec(func)\n", " if signature.varkw is not None: return func(*args, **kwargs)\n", " passed_kwargs = {k:v for k,v in kwargs.items() if k in signature.args}\n", " return func(*args, **passed_kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test\n", "def test_func(a, b, x=2): return a+b+x\n", "test_eq([has_param(test_func, p) for p in ['a', 'c', 'x']], [True,False,True])\n", "test_eq(feed_kwargs(test_func, 1, 2, x=3), 6)\n", "test_eq(feed_kwargs(test_func, 1, 2, y=3), 5)\n", "\n", "def test_func(a, b, x=2, **kwargs): return a+b+x\n", "test_eq(feed_kwargs(test_func, 1, 2, x=3), 6)\n", "test_eq(feed_kwargs(test_func, 1, 2, y=3), 5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def order_sorted(funcs, order_key='_order'):\n", " \"Listify `funcs` and sort it with `order_key`.\"\n", " key = lambda o: getattr(o, order_key, 0)\n", " return sorted(listify(funcs), key=key)\n", "\n", "def apply_all(x, funcs, *args, order_key='_order', filter_kwargs=False, **kwargs):\n", " \"Apply all `funcs` to `x` in order, pass along `args` and `kwargs`.\"\n", " for f in order_sorted(funcs, order_key=order_key): \n", " x = feed_kwargs(f, x, *args, **kwargs) if filter_kwargs else f(x, *args, **kwargs)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "# basic behavior\n", "def _test_f1(x, a=2): return x**a\n", "def _test_f2(x, a=2): return a*x\n", "test_eq(apply_all(2, [_test_f1, _test_f2]),8)\n", "# order\n", "_test_f1._order = 1\n", "test_eq(apply_all(2, [_test_f1, _test_f2]),16)\n", "#args\n", "test_eq(apply_all(2, [_test_f1, _test_f2], 3),216)\n", "#kwargs\n", "test_eq(apply_all(2, [_test_f1, _test_f2], a=3),216)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def mask2idxs(mask): return [i for i,m in enumerate(mask) if m]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def uniqueify(x, sort=False, bidir=False, start=None):\n", " \"Return the unique elements in `x`, optionally `sort`-ed, optionally return the reverse correspondance.\"\n", " res = list(OrderedDict.fromkeys(x).keys())\n", " if start is not None: res = listify(start)+res\n", " if sort: res.sort()\n", " if bidir: return res, {v:k for k,v in enumerate(res)}\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "test_eq(set(uniqueify([1,1,0,5,0,3])),{0,1,3,5})\n", "test_eq(uniqueify([1,1,0,5,0,3], sort=True),[0,1,3,5])\n", "v,o = uniqueify([1,1,0,5,0,3], bidir=True)\n", "test_eq(v,[1,0,5,3])\n", "test_eq(o,{1:0, 0: 1, 5: 2, 3: 3})\n", "v,o = uniqueify([1,1,0,5,0,3], sort=True, bidir=True)\n", "test_eq(v,[0,1,3,5])\n", "test_eq(o,{0:0, 1: 1, 3: 2, 5: 3})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def setify(o): return o if isinstance(o,set) else set(listify(o))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "test_eq(setify(None),set())\n", "test_eq(setify('abc'),{'abc'})\n", "test_eq(setify([1,2,2]),{1,2})\n", "test_eq(setify(range(0,3)),{0,1,2})\n", "test_eq(setify({1,2}),{1,2})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def onehot(x, c, a=1.):\n", " \"Return the `a`-hot encoded tensor for `x` with `c` classes.\"\n", " res = torch.zeros(c)\n", " if a<1: res += (1-a)/(c-1)\n", " res[x] = a\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "test_equal(onehot(1,5), tensor([0.,1.,0.,0.,0.]))\n", "test_equal(onehot([1,3],5), tensor([0.,1.,0.,1.,0.]))\n", "test_equal(onehot(tensor([1,3]),5), tensor([0.,1.,0.,1.,0.]))\n", "test_equal(onehot([True,False,True,True,False],5), tensor([1.,0.,1.,1.,0.]))\n", "test_equal(onehot([],5), tensor([0.,0.,0.,0.,0.]))\n", "\n", "test_equal(onehot(1,5,0.9), tensor([0.025,0.9,0.025,0.025,0.025]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _get_files(p, fs, extensions=None):\n", " p = Path(p)\n", " res = [p/f for f in fs if not f.startswith('.')\n", " and ((not extensions) or f'.{f.split(\".\")[-1].lower()}' in extensions)]\n", " return res\n", "\n", "def get_files(path, extensions=None, recurse=False, include=None):\n", " \"Get all the files in `path` with optional `extensions`.\"\n", " path = Path(path)\n", " extensions = setify(extensions)\n", " extensions = {e.lower() for e in extensions}\n", " if recurse:\n", " res = []\n", " for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)\n", " if include is not None and i==0: d[:] = [o for o in d if o in include]\n", " else: d[:] = [o for o in d if not o.startswith('.')]\n", " res += _get_files(p, f, extensions)\n", " else:\n", " f = [o.name for o in os.scandir(path) if o.is_file()]\n", " res = _get_files(path, f, extensions)\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "path = untar_data(URLs.MNIST_TINY)\n", "test_eq(len(get_files(path/'train'/'3')),346)\n", "test_eq(len(get_files(path/'train'/'3', extensions='.png')),346)\n", "test_eq(len(get_files(path/'train'/'3', extensions='.jpg')),0)\n", "test_eq(len(get_files(path/'train', extensions='.png')),0)\n", "test_eq(len(get_files(path/'train', extensions='.png', recurse=True)),709)\n", "test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train'])),709)\n", "test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train', 'test'])),729)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data block API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DataSource" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`DataSource` is the base class of the data blok API and is defined from `items`, `tfms` and `filters`. It can represent multiple datasets (train, valid, or more) that are contained in the `items`: each element of `filters` is a boolean mask or a collection of ints that says which `items` are in which dataset.\n", "\n", "When accessing an element, `tfms` are applied to it with optional `tfm_kwargs` passed along. Those kwargs are filtered so that each `tfms` only gets the one it accepts. At its base a `tfm` is just a simple function (open an image, resize it, one-hot encode a category, etc.) but it can be more complex (see `Transform` class below). Some transforms need a setup (for instance the transform that changes a cateogry to its index needs to compute all the classes) and some can be reversible for display purposes (if you change a category to an index, you still want to display the category name later on, or if you normalize your image, you need to undo that to display it). `DataSource` calls the potential `setup` function of its `Transform` at initialization and it has a `decode` method that will reverse the transforms that can be reversed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class DataSource():\n", " def __init__(self, items, tfms=None, filters=None, **tfm_kwargs):\n", " if filters is None: filters = [range(len(items))]\n", " if isinstance(filters[0][0], bool): filters = [mask2idxs(filt) for filt in filters]\n", " self.items,self.filters,self.tfms = listify(items),listify(filters),[]\n", " self.tfm_kwargs = tfm_kwargs\n", " tfms = order_sorted(tfms)\n", " for tfm in tfms:\n", " getattr(tfm, 'setup', noop)(self)\n", " self.tfms.append(tfm)\n", " \n", " def transformed(self, tfms, **tfm_kwargs):\n", " tfms = listify(tfms)\n", " tfm_kwargs = {**self.tfm_kwargs, **tfm_kwargs}\n", " return self.__class__(self.items, self.tfms + tfms, self.filters, **tfm_kwargs)\n", " \n", " def __len__(self): return len(self.filters)\n", " def len(self, filt=0): return len(self.filters[filt])\n", " def __getitem__(self, i): return FilteredList(self, i)\n", "\n", " def sublist(self, filt):\n", " return [self.get(j,filt) for j in range(self.len(filt))]\n", "\n", " def get(self, idx, filt=0):\n", " if hasattr(idx,'__len__') and getattr(idx,'ndim',1):\n", " # rank>0 collection\n", " if isinstance(idx[0],bool):\n", " assert len(idx)==self.len(filt) # bool mask\n", " return [self.get(i,filt) for i,m in enumerate(idx) if m]\n", " return [self.get(i,filt) for i in idx] # index list\n", " if self.filters: idx = self.filters[filt][idx]\n", " res = self.items[idx]\n", " if self.tfms: res = apply_all(res, self.tfms, filt=filt, filter_kwargs=True, **self.tfm_kwargs)\n", " return res\n", " \n", " def decode(self, o, filt=0):\n", " if self.tfms: \n", " return apply_all(o, [getattr(f, 'decode', noop) for f in reversed(self.tfms)], \n", " filt=filt, filter_kwargs=True, **self.tfm_kwargs)\n", "\n", " def __iter__(self):\n", " for i in range_of(self.filters):\n", " yield (self.get(j,i) for j in range(self.len(i)))\n", " \n", " def __eq__(self,b):\n", " if not isinstance(b,DataSource): b = DataSource(b)\n", " if len(b) != len(self): return False\n", " for i in range_of(self.filters):\n", " if b.len(i) != self.len(i): return False\n", " return all(self.get(j,i)==b.get(j,i) for j in range_of(self.filters[i]))\n", "\n", " def __repr__(self):\n", " res = f'{self.__class__.__name__}\\n'\n", " for i,o in enumerate(self):\n", " l = self.len(i)\n", " res += f'{i}: ({l} items) ['\n", " res += ','.join(itertools.islice(map(str,o), 10))\n", " if l>10: res += '...'\n", " res += ']\\n'\n", " return res\n", " \n", " @property\n", " def train(self): return self[0]\n", " @property\n", " def valid(self): return self[1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `FilteredList` is a convenience access to one dataset of a `DataSource`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class FilteredList:\n", " def __init__(self, dsrc, filt): self.dsrc,self.filt = dsrc,filt\n", " def __getitem__(self,i): return self.dsrc.get(i,self.filt)\n", " def __len__(self): return self.dsrc.len(self.filt)\n", " \n", " def __iter__(self):\n", " return (self.dsrc.get(j,self.filt) for j in range_of(self))\n", " \n", " def __repr__(self):\n", " res = f'({len(self)} items) ['\n", " res += ','.join(itertools.islice(map(str,self), 10))\n", " if len(self)>10: res += '...'\n", " res += ']\\n'\n", " return res\n", " \n", " def decode(self, o): return self.dsrc.decode(o, self.filt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Tests" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "#Indexing\n", "dsrc = DataSource(range(5))\n", "test_eq(dsrc,[0,1,2,3,4])\n", "test_eq(dsrc.sublist(0),[0,1,2,3,4])\n", "test_ne(dsrc,[0,1,2,3,5])\n", "test_eq(dsrc.get(2),2)\n", "test_eq(dsrc.get([1,2]),[1,2])\n", "test_eq(dsrc.get([True,False,False,True,False]),[0,3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "#filters can be indices or boolean masks\n", "dsrc = DataSource(range(5), filters=[[0,2], [1,3,4]])\n", "test_eq(list(dsrc[0]),[0,2])\n", "test_eq(list(dsrc[1]),[1,3,4])\n", "#Subsets don't have to be disjoints\n", "dsrc = DataSource(range(5), filters=[[False,True,True,False,True], [True,False,False,True,True]])\n", "test_eq(list(dsrc[0]),[1,2,4])\n", "test_eq(list(dsrc[1]),[0,3,4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "#Base transform\n", "dsrc = DataSource(range(5), lambda x:x*2)\n", "test_eq(dsrc,[0,2,4,6,8])\n", "test_eq(dsrc.sublist(0),[0,2,4,6,8])\n", "test_ne(dsrc,[1,2,4,6,8])\n", "test_eq(dsrc.get(2), 4)\n", "test_eq(dsrc.get([1,2]), [2,4])\n", "test_eq(dsrc.get([True,False,False,True,False]), [0,6])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "#Different transforms for the two subsets\n", "dsrc = DataSource(range(5), lambda x,filt:x if filt == 0 else x*2, [[1,2],[0,3,4]])\n", "test_eq(dsrc.sublist(0),[1,2])\n", "test_eq(dsrc.sublist(1),[0,6,8])\n", "test_eq(dsrc.get(2,1), 8)\n", "test_eq(dsrc.get([1,2], 1), [6,8])\n", "test_eq(dsrc.get([False,True], 0), [2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "def add(x, a=1): return x+a\n", "def multiply(x, a=2): return x*a\n", "def square(x): return x**2\n", "\n", "#Tfms are ordered by their `_order` ket when applied\n", "#Test _order\n", "square._order = 0\n", "multiply._order = 1\n", "add._order = 2\n", "dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]])\n", "test_eq(dsrc.get(2), ((2**2) * 2) + 1)\n", "\n", "#Kwargs are passed to tfms when they can be\n", "#Test kwargs\n", "dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]], a=3)\n", "test_eq(dsrc.get(2), ((2**2) * 3) + 3)\n", "\n", "#Test decode\n", "def add_undo(x, a=1): return x-a\n", "def multiply_undo(x, a=2): return x/a\n", "add.decode = add_undo\n", "multiply.decode = multiply_undo\n", "dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]])\n", "test_eq(dsrc.decode(9), (9-1)/2)\n", "dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]], a=3)\n", "test_eq(dsrc.decode(9), (9-3)/3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "dsrc = DataSource(range(5), lambda x,filt:x if filt == 0 else x*2, [[1,2],[0,3,4]])\n", "fl = dsrc[1]\n", "test_eq(list(fl),[0,6,8])\n", "test_eq(fl[2], 8)\n", "test_eq(fl[[1,2]], [6,8])\n", "test_eq(fl[[False,True,True]], [6,8])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the base class to define a transform when we want something more complex than a function. The `_order` helps sort the transforms before applying them. `setup` is a preparation step to get the state ready on the `data` (which is a `DataSource`). `__call__` is the main function that applies the transform to `o` and `decode` does the reverse operation.\n", "\n", "**NB: You should only implement `decode` if your transform needs to be reversed for display purposes.** For instance we want to reverse the operation *class to index*, but we don't want to reverse the operation *open this image*." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Transform():\n", " _order = 0\n", " def setup(self, dsrc): return # 1-time setup\n", " def __call__(self,o): return o # transform\n", " def decode(self,o): return o # reverse transform for display" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On top of this, a tranform that is applied to define a base object can have a `show` method: for instance the transform that opens an `Image` has a show method. When trying to display objects, the API will decode it and grab the first transform that provides a `show` method (this can be overriden by passing a custom `show` but we'll see that later).\n", "\n", "The next transform is a bit more complex and is responsible for converting a single item to xs/ys." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _get_show_func(tfms):\n", " for t in reversed(tfms):\n", " if hasattr(t, 'show') and t.show is not None: return t.show\n", " return None\n", "\n", "def show_xs(xs, shows, ax=None, **kwargs):\n", " for x,show in zip(xs,shows):\n", " # can pass func or obj with a `show` method\n", " show = getattr(show, 'show', show)\n", " ax = feed_kwargs(show, x, ax=ax, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TupleTransform():\n", " def __init__(self, *tfms): self.tfms = [order_sorted(tfm) for tfm in listify(tfms)]\n", " \n", " def __call__(self, o, filt=0, **kwargs): \n", " return [apply_all(o, tfm, filt=filt, filter_kwargs=True, **kwargs) for tfm in self.tfms]\n", " \n", " def decode(self, o, filt=0, **kwargs): \n", " return [apply_all(x, [getattr(f, 'decode', noop) for f in reversed(tfm)], filt=filt, \n", " filter_kwargs=True, **kwargs) \n", " for x,tfm in zip(o,self.tfms)]\n", " \n", " def setup(self, dsrc):\n", " old_tfms = getattr(dsrc, 'tfms', []).copy()\n", " for tfm in self.tfms:\n", " for t in tfm:\n", " getattr(t, 'setup', noop)(dsrc)\n", " dsrc.tfms.append(t)\n", " dsrc.tfms = old_tfms.copy()\n", "\n", " def show(self, o, shows=None, **kwargs):\n", " shows = shows or [None]*len(self.tfms)\n", " shows = [ifnone(show, _get_show_func(tfm)) for tfm,show in zip(self.tfms,shows)]\n", " show_xs(o, shows, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Helper functions to create blocks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Get image files" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "image_extensions = set(k for k,v in mimetypes.types_map.items() if v.startswith('image/'))\n", "\n", "def get_image_files(path, include=None, **kwargs):\n", " \"Get image files in `path` recursively.\"\n", " return get_files(path, extensions=image_extensions, recurse=True, include=include)\n", "\n", "def image_getter(suf='', **kwargs):\n", " def _inner(o, **kwargs):\n", " return get_image_files(o/suf, **kwargs)\n", " return _inner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "path = untar_data(URLs.MNIST_TINY)\n", "test_eq(len(get_image_files(path)),1428)\n", "test_eq(len(get_image_files(path/'train')),709)\n", "test_eq(len(get_image_files(path, include='train')),709)\n", "test_eq(len(get_image_files(path, include=['train','valid'])),1408)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "path = untar_data(URLs.MNIST_TINY)\n", "test_eq(len(image_getter()(path)),1428)\n", "test_eq(len(image_getter('train')(path)),709)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convention: a function that has the name of a verb and ends with *er* returns a function (to get transforms directly or for use in the high level API belox)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def random_splitter(valid_pct=0.2, seed=None, **kwargs):\n", " \"Split `items` between train/val with `valid_pct` randomly.\"\n", " def _inner(o, **kwargs):\n", " if seed is not None: torch.manual_seed(seed)\n", " rand_idx = torch.randperm(len(o))\n", " cut = int(valid_pct * len(o))\n", " return rand_idx[cut:],rand_idx[:cut]\n", " return _inner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test\n", "trn,val = random_splitter(seed=42)([0,1,2,3,4,5])\n", "test_equal(trn, tensor([3, 2, 4, 1, 5]))\n", "test_equal(val, tensor([0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _grandparent_mask(items, name):\n", " return [(o.parent.parent.name if isinstance(o, Path) else o.split(os.path.sep)[-2]) == name for o in items]\n", "\n", "def grandparent_splitter(train_name='train', valid_name='valid', **kwargs):\n", " \"Split `items` from the grand parent folder names (`train_name` and `valid_name`).\"\n", " def _inner(o, **kwargs):\n", " return _grandparent_mask(o, train_name),_grandparent_mask(o, valid_name)\n", " return _inner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_TINY)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test\n", "#With string filenames\n", "path = untar_data(URLs.MNIST_TINY)\n", "items = [path/'train'/'3'/'9932.png', path/'valid'/'7'/'7189.png', \n", " path/'valid'/'7'/'7320.png', path/'train'/'7'/'9833.png', \n", " path/'train'/'3'/'7666.png', path/'valid'/'3'/'925.png',\n", " path/'train'/'7'/'724.png', path/'valid'/'3'/'93055.png']\n", "trn,val = grandparent_splitter()(items)\n", "test_eq(trn,[True,False,False,True,True,False,True,False])\n", "test_eq(val,[False,True,True,False,False,True,False,True])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Label" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def parent_label(o, **kwargs):\n", " \"Label `item` with the parent folder name.\"\n", " return o.parent.name if isinstance(o, Path) else o.split(os.path.sep)[-1]\n", "\n", "def re_labeller(pat):\n", " \"Label `item` with regex `pat`.\"\n", " pat = re.compile(pat)\n", " def _inner(o, **kwargs):\n", " res = pat.search(str(o))\n", " assert res,f'Failed to find \"{pat}\" in \"{o}\"'\n", " return res.group(1)\n", " return _inner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pets DataSource" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's grab the Pets dataset first. Our `DataSource` will contain all the image files as items, and we'll randomly select two filters with 80% and 20% of the data. To get our xs, we need to apply a `Transform` that opens the image in the filenames. We'll call it `Imagify`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def show_image(im, ax=None, figsize=None, **kwargs):\n", " \"Show a PIL image on `ax`.\"\n", " if ax is None: _,ax = plt.subplots(figsize=figsize)\n", " if isinstance(im,Tensor) and im.shape[0]<5: im=im.permute(1,2,0)\n", " ax.imshow(im, **kwargs)\n", " ax.axis('off')\n", " return ax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Imagify(Transform):\n", " def __init__(self, f=PIL.Image.open, cmap=None, alpha=1.): self.f,self.cmap,self.alpha = f,cmap,alpha\n", " def __call__(self, fn): return PIL.Image.open(fn)\n", " def show(self, im, ax=None, figsize=None, cmap=None, alpha=None):\n", " cmap = ifnone(cmap,self.cmap)\n", " alpha = ifnone(alpha,self.alpha)\n", " return show_image(im, ax, figsize=figsize, cmap=cmap, alpha=alpha)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get our ys, we'll need to apply the re pattern (function `re_labeller` from before) and a transform that creates the list of classes, then maps categories to their indices. We call that transform `Categorize`. It needs a `setup` to create the classes, and it's reversible so we implement `decode`. This is also a base a transform so we implement its `show` method.\n", "\n", "Since it needs to run after the labelling transform (here our re pattern labeller) we give it an `_order` of 1." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Categorize(Transform):\n", " _order=1\n", " def __init__(self): self.vocab = None\n", " def __call__(self,o): return self.o2i[o]\n", " def decode(self, o): return self.vocab[o]\n", " def show(self, o, ax=None): \n", " if ax is None: print(o)\n", " else: ax.set_title(o)\n", " \n", " def setup(self, dsrc):\n", " if self.vocab is not None: return\n", " vals = [o for o in dsrc.train]\n", " self.vocab,self.o2i = uniqueify(vals, sort=True, bidir=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can create a `DataSource` that contains our dataset. We grab all the image files, split them randomly and build a `TupleTransform` from open an image / labelling + categorizing." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.PETS)/\"images\"\n", "items = get_image_files(source)\n", "split_idx = random_splitter()(items)\n", "xt = Imagify()\n", "yt = Categorize()\n", "labeller = re_labeller(pat = r'/([^/]+)_\\d+.jpg$')\n", "tfm = TupleTransform(xt,[labeller,yt])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets = DataSource(items, tfm, split_idx)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To access an element we need to specify index/filter (the latter defaults to 0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xy = pets.get(0,0)\n", "xy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can decode an element for display purposes!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xy = pets.decode((xy), 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `show_xs` function is there to combine the show methods of the base transforms to display x and y together. We can either pass a transform that has a show method or a custom list of `show` methods." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_xs(xy, (xt, yt))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's monkey-patch a show method to `DataSource` to do this automatically for us. The `TupleTransform` will use `show_xs` by default, but can either pass a custom `show_func`, or also use the kwargs to pass along a custom list of show methods (set `None` for the ones you don't want to override)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _dsrc_show(self, o, filt=0, show_func=None, **kwargs):\n", " o = self.decode(o, filt)\n", " if show_func is None: show_func=_get_show_func(self.tfms)\n", " show_func(o, **kwargs)\n", " \n", "DataSource.show = _dsrc_show\n", "\n", "def _fl_show(self, o, show_func=None, **kwargs):\n", " o = self.decode(o)\n", " if show_func is None: show_func=_get_show_func(self.dsrc.tfms)\n", " show_func(o, **kwargs)\n", "\n", "FilteredList.show = _fl_show" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets.show(pets.get(0,0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Image Transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we can batch our images, we'll need to apply some basic image transformations: conerting to RGB, making them all the same size and also converting them to tensors. We have to get prepared for different kind of targets: sometimes the target won't be applied the transform, but sometimes it will and in different ways. We support images, segmentation masks, points or bounding boxes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "TfmY = Enum('TfmY', 'Mask Image Point Bbox No')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `ImageTransform` class seems ugly but it just dispatches the `apply` or `decode` function properly between `x` and `y`, and allow different implementations of each function. This will be very important for data augmentation in the next notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ImageTransform():\n", " \"Basic class for image transforms.\"\n", " _order,_data_aug = 10,False\n", " \n", " def randomize(self): pass\n", " \n", " def __call__(self, o, filt=0, **kwargs):\n", " if self._data_aug and filt != 0: return o\n", " x,*y = o\n", " self.x,self.filt = x,filt # Saves the x in case it's needed in the apply for y and filt\n", " self.randomize() # Ensures we have the same state for x and y\n", " return (self.apply(x),) + tuple(self.apply_y(y_, **kwargs) for y_ in y)\n", " \n", " def decode(self, o, filt=0, **kwargs):\n", " if self._data_aug and filt != 0: return o\n", " x,*y = o\n", " self.x,self.filt = x,filt\n", " return (self.unapply(x),) + tuple(self.unapply_y(y_, **kwargs) for y_ in y)\n", "\n", " def _tfm_name(self, t, is_decode=False):\n", " return ('unapply_' if is_decode else 'apply_') + t.name.lower()\n", "\n", " def apply_no(self,y): return y\n", " def apply_image(self, y): return self.apply(y)\n", " def apply_mask(self, y): return self.apply_image(y)\n", " def apply_point(self, y): return y\n", " def apply_bbox(self, y): return self.apply_point(y)\n", "\n", " def apply(self, x): return x\n", " def apply_y(self, y, tfm_y=TfmY.No):\n", " return getattr(self, self._tfm_name(tfm_y))(y)\n", " \n", " def unapply_no(self,y): return y\n", " def unapply_image(self, y): return self.unapply(y)\n", " def unapply_mask(self, y): return self.unapply_image(y)\n", " def unapply_point(self, y): return y\n", " def unapply_bbox(self, y): return self.unapply_point(y)\n", " \n", " def unapply(self, x): return x\n", " def unapply_y(self, y, tfm_y=TfmY.No):\n", " return getattr(self, self._tfm_name(tfm_y,True))(y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "import random\n", "class FakeTransform(ImageTransform):\n", " def randomize(self): self.a = random.randint(1,10)\n", " def apply(self, x): return x + self.a\n", " def apply_mask(self, x): return x + 5\n", " def apply_point(self, x): return x + 2\n", "\n", "tfm = FakeTransform()\n", "(x,y) = (5,10)\n", "#Basic behavior: x has changed, not y\n", "t1 = tfm((x,y))\n", "assert t1[0]!=x and t1[1]==y, t1\n", "#Check the same random integer was used for x and y when transforming y\n", "t1 = tfm((x,y), tfm_y=TfmY.Image)\n", "test_eq(t1[0] - 5,t1[1] - 10)\n", "#Check mask, point,bbox implementations\n", "test_eq(tfm((x,y), tfm_y=TfmY.Mask) [1],15)\n", "test_eq(tfm((x,y), tfm_y=TfmY.Point)[1],12)\n", "test_eq(tfm((x,y), tfm_y=TfmY.Bbox) [1],12)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our first transform decodes an image to 'RGB'. We can specify different modes for the xs and ys, the default is 'RGB' for x, then mode_x for y if our ys are images, 'L' for y if our ys are segmentation masks." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class DecodeImg(ImageTransform):\n", " \"Convert regular image to RGB, masks to L mode.\"\n", " def __init__(self, mode_x='RGB', mode_y=None): self.mode_x,self.mode_y = mode_x,mode_y\n", " def apply(self, x): return x.convert(self.mode_x)\n", " def apply_image(self, y): return y.convert(ifnone(self.mode_y,self.mode_x))\n", " def apply_mask(self, y): return y.convert(ifnone(self.mode_y,'L'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our second transform resizes an image, using a given mode. It defaults to bilinear for images and nearest for segmentation masks." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ResizeFixed(ImageTransform):\n", " \"Resize image to `size` using `mode_x` (and `mode_y` on targets).\"\n", " _order=15\n", " def __init__(self, size, mode_x=PIL.Image.BILINEAR, mode_y=None):\n", " if isinstance(size,int): size=(size,size)\n", " self.size = (size[1],size[0]) #PIL takes size in the otherway round\n", " self.mode_x,self.mode_y = mode_x,mode_y\n", " \n", " def apply(self, x): return x.resize(self.size, self.mode_x)\n", " def apply_image(self, y): return y.resize(self.size, ifnone(self.mode_y,self.mode_x))\n", " def apply_mask(self, y): return y.resize(self.size, ifnone(self.mode_y,PIL.Image.NEAREST))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The transformation to tensors is done in two steps just in case one wants to apply transforms to byte tensors. The permutation of axes needs to be reversed for display, so we have an `unapply` function (which is what is called by `decode` in an `ImageTransform`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ToByteTensor(ImageTransform):\n", " \"Transform our items to byte tensors.\"\n", " _order=20\n", " def apply(self, x):\n", " res = torch.ByteTensor(torch.ByteStorage.from_buffer(x.tobytes()))\n", " w,h = x.size\n", " return res.view(h,w,-1).permute(2,0,1)\n", " \n", " def unapply(self, x):\n", " return x[0] if x.shape[0] == 1 else x.permute(1,2,0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly we convert our tensors to floats (or ints for segmentation masks) and divides by 255 (can specify a different value and a `div_y`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ToFloatTensor(ImageTransform):\n", " \"Transform our items to float tensors (int in the case of mask).\"\n", " _order=5 #Need to run after CUDA on the GPU\n", " def __init__(self, div_x=255., div_y=None): self.div_x,self.div_y = div_x,div_y\n", " def apply(self, x): return x.float().div_(self.div_x)\n", " def apply_mask(self, x): \n", " return x.long() if self.div_y is None else x.long().div_(self.div_y)\n", " \n", " def unapply(self, x): return torch.clamp(x, 0, 1)\n", " def unapply_mask(self, x): return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's test it's all work properly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets_t = pets.transformed(tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets_t.show(pets_t.get(0,0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DataBunch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With transforms to make our images tensors of the same size, we're ready to create batches and dataloaders. We wrap a PyTorch dataloader to add batch transforms. Additional kwargs will be passed along. Those transforms can be decoded like before, for display purposes (like normalization)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class TfmDataLoader():\n", " def __init__(self, dl, tfms=None, **tfm_kwargs):\n", " self.dl,self.tfms,self.tfm_kwargs = dl,order_sorted(tfms),tfm_kwargs\n", " \n", " def __len__(self): return len(self.dl)\n", " def __iter__(self):\n", " for b in self.dl: yield apply_all(b, self.tfms, filter_kwargs=True, **self.tfm_kwargs)\n", " \n", " def one_batch(self): return next(iter(self))\n", " def decode_batch(self): return self.decode(self.one_batch())\n", " def decode(self, o): \n", " return apply_all(o, [getattr(f, 'decode', noop) for f in reversed(self.tfms)],\n", " filter_kwargs=True, **self.tfm_kwargs)\n", " \n", " @property\n", " def dataset(self): return self.dl.dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we add a basic function to create dataloaders from a `DataSource`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "from torch.utils.data.dataloader import DataLoader\n", "\n", "def get_dl(dset, bs=64, tfms=None, tfm_kwargs=None, **kwargs):\n", " dl = DataLoader(dset, bs, **kwargs)\n", " return TfmDataLoader(dl, tfms=tfms, **(ifnone(tfm_kwargs,{})))\n", "\n", "def get_dls(dsrc, bs=64, tfms=None, tfm_kwargs=None, **kwargs):\n", " return [get_dl(dsrc[i], bs, shuffle=i==0, tfms=tfms, tfm_kwargs=tfm_kwargs, **kwargs)\n", " for i in range_of(dsrc)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = get_dls(pets_t, tfms=ToFloatTensor())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a convenience function to grab the k-th item in a batch, even if the batch is constituted of lists of tensors." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def grab_item(b,k):\n", " if isinstance(b, (list,tuple)): return [grab_item(o,k) for o in b]\n", " return b[k]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class DataBunch():\n", " \"Basic wrapper around several `DataLoader`s.\"\n", " def __init__(self, *dls): self.dls = dls\n", " def __getitem__(self, i): return self.dls[i]\n", " \n", " @property\n", " def train_dl(self): return self.dls[0]\n", " @property\n", " def valid_dl(self): return self.dls[1]\n", " @property\n", " def train_ds(self): return self.train_dl.dataset\n", " @property\n", " def valid_ds(self): return self.valid_dl.dataset\n", "\n", " def show_batch(self, i=0, items=9, cols=3, figsize=None, show_func=None, **kwargs):\n", " b = self.dls[i].decode_batch()\n", " rows = (items+cols-1) // cols\n", " if figsize is None: figsize = (cols*3, rows*3)\n", " fig,axs = plt.subplots(rows, cols, figsize=figsize)\n", " for k,ax in enumerate(axs.flatten()):\n", " self[i].dataset.show(grab_item(b,k), ax=ax, show_func=show_func, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch(*dls)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = data[0].one_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x.shape,x.type(),y.shape,y.type()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally let's monkey-patch a `databunch` function in `DataSource` to quickly create a `databunch`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _dsrc_databunch(self, bs=64, tfms=None, tfm_kwargs=None, **kwargs):\n", " return DataBunch(*get_dls(self, bs=bs, tfms=tfms, tfm_kwargs=tfm_kwargs, **kwargs))\n", "\n", "DataSource.databunch = _dsrc_databunch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cuda" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda',0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "from fastai.torch_core import to_device, to_cpu\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Cuda(Transform):\n", " _order = 0\n", " def __init__(self,device): self.device=device\n", " def __call__(self, b, tfm_y=TfmY.No): return to_device(b, self.device)\n", " def decode(self, b): return to_cpu(b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Normalization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see other batch transforms in the next chapter but one that is pretty common is normalization." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Normalize(Transform):\n", " _order=99\n", " def __init__(self, mean, std, do_x=True, do_y=False):\n", " self.mean,self.std,self.do_x,self.do_y = mean,std,do_x,do_y\n", " \n", " def __call__(self, b):\n", " x,y = b\n", " if self.do_x: x = self.normalize(x)\n", " if self.do_y: y = self.normalize(y)\n", " return x,y\n", " \n", " def decode(self, b):\n", " x,y = b\n", " if self.do_x: x = self.denorm(x)\n", " if self.do_y: y = self.denorm(y)\n", " return x,y\n", " \n", " def normalize(self, x): return (x - self.mean) / self.std\n", " def denorm(self, x): return x * self.std + self.mean" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mean,std = tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda(),tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda()\n", "data = pets_t.databunch(tfms = [Cuda(device), ToFloatTensor(), Normalize(mean,std)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Higher level API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make it easy to use this data block API, we add a high level class. User provides the type of inputs/targets in `types`. Then they implement the three following functions to gather the data:\n", "\n", "- `get_items` takes the source and retun the list of all items\n", "- `split` take the items and returns two (or more) list of indices or boolean masks that explain how to split the data in train and valid (potentially valids) set.\n", "- `label_func` return the corresponding label for on item.\n", "\n", "Then during the intilialization, default transforms for `x`, `y` and the full datasource are collected (they can be overriden by a custom `tfms_x`, `tfms_y` or `tfm_ds`).\n", "\n", "When calling `datasource`, the `source` is fetched by calling `get_source`, which then allows to collect the items (with `get_items`) and the different splits (with `split`). `label_func` is added to the `y` transforms and a `DataSource` can be created, with additional `tfms` passed.\n", "\n", "When calling `databunch`, the `datasource` (created with `ds_tfms`) is converted, with additional batch transforms (in `dl_tfms`).\n", "\n", "An `Item` is just a class containing three attributes:\n", "\n", "- `tfm` default transforms associated to that item\n", "- `tfm_ds` default transforms associated to that item that are applied to the tuple (x,y)\n", "- `tfm_kwargs` default kwargs to pass to all transforms (it will be filtered and only passed to the transforms that accept them). For instance `{'tfm_y': TfmY.Mask}` in a `SegmentMask`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds_tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor()]\n", "dl_tfms = [Cuda(device), ToFloatTensor()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Item(): tfm,tfm_ds,tfm_kwargs = None,None,None\n", "\n", "def resolve_tfms(o, tfmx, tfmy=None):\n", " if o is not None: return o\n", " return [t() for t in listify(tfmx)+listify(tfmy)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class DataBlock():\n", " types = (Item,Item)\n", " @staticmethod\n", " def get_items(source): raise NotImplementedError\n", " @staticmethod\n", " def split(items): raise NotImplementedError\n", " @staticmethod\n", " def label_func(item): raise NotImplementedError\n", " \n", " def __init__(self, source, tfms=None, tfms_ds=None):\n", " self.source = source\n", " if tfms is None: tfms = (None,)*len(self.types)\n", " self.tfms = [resolve_tfms(tfm, x.tfm) for tfm,x in zip(tfms,self.types)]\n", " self.tfms_ds = resolve_tfms(tfms_ds, *[getattr(x,\"tfm_ds\") for x in self.types[:2]])\n", " self.tfm_kwargs = {}\n", " for t in self.types: self.tfm_kwargs.update(t.tfm_kwargs or {})\n", " \n", " def datasource(self, tfms=None, **tfm_kwargs):\n", " cls = self.__class__\n", " items = cls.get_items(self.source, self=self)\n", " split_idx = cls.split(items, self=self)\n", " lfs = getattr(cls, 'label_funcs', (noop,cls.label_func))\n", " ttfms = [[partial(lf, self=self)]+listify(tfm) for lf,tfm in zip(lfs,self.tfms)]\n", " ds = DataSource(items, TupleTransform(*ttfms), split_idx)\n", " ds = ds.transformed(self.tfms_ds + listify(tfms), **{**self.tfm_kwargs, **tfm_kwargs})\n", " return ds\n", " \n", " def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):\n", " tfm_kwargs = ifnone(tfm_kwargs, {})\n", " dls = get_dls(self.datasource(tfms=ds_tfms, **tfm_kwargs), bs, tfms=dl_tfms, \n", " tfm_kwargs={**self.tfm_kwargs, **tfm_kwargs}, **kwargs)\n", " return DataBunch(*dls)\n", " \n", " @property\n", " def xt(self): return self.tfms[0][0]\n", " \n", " @property\n", " def yt(self): return self.tfms[1][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are some examples of items:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Image(Item): tfm = Imagify\n", "class Category(Item): tfm = Categorize" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here is an example of use of the API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PetsData(DataBlock):\n", " types = Image,Category\n", " get_items = image_getter()\n", " split = random_splitter()\n", " label_func = re_labeller(pat = r'/([^/]+)_\\d+.jpg$')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.PETS)/\"images\"\n", "dsrc = PetsData(source)\n", "data = dsrc.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms)\n", "data.show_batch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "' '.join(dsrc.yt.vocab)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Try different data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### MNIST" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MnistData(DataBlock):\n", " types = (Image, Category)\n", " get_items = get_image_files\n", " split = grandparent_splitter(train_name='training', valid_name='testing')\n", " label_func = parent_label" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.MNIST)\n", "data = MnistData(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are seveal ways to get the show display properly our images. First we can pass a custom `show_func` method and change the function that shows the `x`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch(shows = (partial(show_image, cmap='gray'), None))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or trust the API to dispatch the kwargs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch(cmap='gray')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or we could create a `BlackAndWhiteImage` class that uses the transform `Imagify` with a default `cmap` to gray:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class BlackAndWhiteImage(Item):\n", " tfm = partial(Imagify, cmap='gray')\n", "\n", "class MnistData(DataBlock):\n", " types = (BlackAndWhiteImage, Category)\n", " get_items = get_image_files\n", " split = grandparent_splitter(train_name='training', valid_name='testing')\n", " label_func = parent_label" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = MnistData(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Planet" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.PLANET_SAMPLE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class MultiCategorize(Transform):\n", " _order=1\n", " def __init__(self): self.vocab = None\n", " def __call__(self,x): return [self.o2i[o] for o in x if o in self.o2i]\n", " def decode(self, o): return [self.vocab[i] for i in o]\n", " def show(self, o, ax=None): \n", " (print if ax is None else ax.set_title)(';'.join(o))\n", " \n", " def setup(self, dsrc):\n", " if self.vocab is not None: return\n", " vals = set()\n", " for c in dsrc.train: vals = vals.union(set(c))\n", " self.vocab,self.o2i = uniqueify(list(vals), sort=True, bidir=True)\n", " \n", "class OneHotEncode(Transform):\n", " _order=10\n", " def setup(self, items):\n", " self.c = None\n", " for tfm in items.tfms:\n", " if isinstance(tfm, MultiCategorize): self.c = len(tfm.vocab)\n", " \n", " def __call__(self, o): return onehot(o, self.c) if self.c is not None else o\n", " def decode(self, o): return [i for i,x in enumerate(o) if x == 1]\n", " \n", "class MultiCategory(Item):\n", " tfm = [MultiCategorize, OneHotEncode]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "tfm = MultiCategorize()\n", "#Even if 'c' is the first class, vocab is sorted for reproducibility\n", "ds = DataSource([['c','a'], ['a','b'], ['b'], []], filters=[[0,1,2,3], []])\n", "tfm.setup(ds)\n", "test_eq(tfm.vocab,['a','b','c'])\n", "test_eq(tfm(['b','a']),[1,0])\n", "test_eq(tfm.decode([2,0]),['c','a'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def get_str_column(df, col_name, prefix='', suffix='', delim=None):\n", " \"Read `col_name` in `df`, optionnally adding `prefix` or `suffix`.\"\n", " values = df[col_name].values.astype(str)\n", " values = np.char.add(np.char.add(prefix, values), suffix)\n", " if delim is not None:\n", " values = np.array(list(csv.reader(values, delimiter=delim)))\n", " return values" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# test\n", "df = pd.DataFrame({'a': ['cat', 'dog', 'car'], 'b': ['a b', 'c d', 'a e']})\n", "test_np_eq(get_str_column(df, 'a'), np.array(['cat', 'dog', 'car']))\n", "test_np_eq(get_str_column(df, 'a', prefix='o'), np.array(['ocat', 'odog', 'ocar']))\n", "test_np_eq(get_str_column(df, 'a', suffix='.png'), np.array(['cat.png', 'dog.png', 'car.png']))\n", "test_np_eq(get_str_column(df, 'b', delim=' '), np.array([['a','b'], ['c','d'], ['a','e']]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PlanetData(DataBlock):\n", " types = Image,MultiCategory\n", " \n", " def get_items(source, self): \n", " df = pd.read_csv(self.source/'labels.csv')\n", " items = get_str_column(df, 'image_name', prefix=f'{self.source}/train/', suffix='.jpg')\n", " labels = get_str_column(df, 'tags', delim=' ')\n", " self.item2label = {i:s for i,s in zip(items,labels)}\n", " return items\n", " \n", " split = random_splitter()\n", " def label_func(item, self): return self.item2label[item]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.PLANET_SAMPLE)\n", "dsrc = PlanetData(source)\n", "data = dsrc.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Camvid" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class SegmentMask(Item):\n", " tfm = partial(Imagify, cmap='tab20', alpha=0.5)\n", " tfm_kwargs = {'tfm_y': TfmY.Mask}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CamvidData(DataBlock):\n", " types = Image,SegmentMask\n", " get_items = image_getter('images')\n", " split = random_splitter()\n", " label_func = lambda o,self: self.source/'labels'/f'{o.stem}_P{o.suffix}'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.CAMVID_TINY)\n", "data = CamvidData(source).databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch(cmap='tab20')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Biwii" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class PointScaler(Transform):\n", " _order = 5 #Run before we apply any ImageTransform\n", " def __init__(self, do_scale=True, y_first=False): \n", " self.do_scale,self.y_first = do_scale,y_first\n", " \n", " def __call__(self, o, tfm_y=TfmY.No):\n", " x,y = o\n", " if not isinstance(y, torch.Tensor): y = tensor(y)\n", " y = y.view(-1, 2).float()\n", " if not self.y_first: y = y.flip(1)\n", " if self.do_scale: y = y * 2/tensor(list(x.size)).float() - 1\n", " return (x,y)\n", " \n", " def decode(self, o, tfm_y=TfmY.No):\n", " x,y = o\n", " y = y.flip(1)\n", " y = (y+1) * tensor([x.shape[:2]]).float()/2\n", " return (x,y)\n", "\n", "class PointShow(Transform):\n", " def show(self, x, ax=None):\n", " params = {'s': 10, 'marker': '.', 'c': 'r'}\n", " ax.scatter(x[:, 1], x[:, 0], **params)\n", "\n", "class Points(Item):\n", " tfm = PointShow\n", " tfm_ds = PointScaler\n", " tfm_kwargs = {'tfm_y': TfmY.Point}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class BiwiData(DataBlock):\n", " types = Image,Points\n", " def __init__(self, source, *args, **kwargs):\n", " super().__init__(source, *args, **kwargs)\n", " self.fn2ctr = pickle.load(open(source/'centers.pkl', 'rb'))\n", " \n", " get_items = image_getter('images')\n", " split = random_splitter()\n", " label_func = lambda o,self: self.fn2ctr[o.name]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dblk = BiwiData(untar_data(URLs.BIWI_SAMPLE))\n", "data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Coco" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "from fastai.vision.data import get_annotations\n", "from matplotlib import patches, patheffects\n", "\n", "def _draw_outline(o, lw):\n", " o.set_path_effects([patheffects.Stroke(linewidth=lw, foreground='black'), patheffects.Normal()])\n", "\n", "def _draw_rect(ax, b, color='white', text=None, text_size=14, hw=True, rev=False):\n", " lx,ly,w,h = b\n", " if rev: lx,ly,w,h = ly,lx,h,w\n", " if not hw: w,h = w-lx,h-ly\n", " patch = ax.add_patch(patches.Rectangle((lx,ly), w, h, fill=False, edgecolor=color, lw=2))\n", " _draw_outline(patch, 4)\n", " if text is not None:\n", " patch = ax.text(lx,ly, text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')\n", " _draw_outline(patch,1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class BBoxScaler(PointScaler):\n", " def __call__(self, o, tfm_y=TfmY.Bbox): \n", " x,y = o\n", " return x, (super().__call__((x,y[0]))[1].view(-1,4),y[1])\n", " \n", " def decode(self, o, tfm_y=TfmY.Bbox): \n", " x,y = o\n", " _,bbox = super().decode((x,y[0].view(-1,2)))\n", " return x, (bbox.view(-1,4),y[1])\n", " \n", "class BBoxEncoder(Transform):\n", " _order=1\n", " def __init__(self): self.vocab = None\n", " \n", " def __call__(self,o):\n", " x,y = o\n", " return (x,[self.otoi[o_] for o_ in y if o_ in self.otoi])\n", " \n", " def decode(self, o):\n", " x,y = o\n", " return x, [self.vocab[i] for i in y]\n", " \n", " def setup(self, dsrc):\n", " if self.vocab is not None: return\n", " vals = set()\n", " for bb,c in dsrc.train: vals = vals.union(set(c))\n", " self.vocab,self.otoi = uniqueify(list(vals), sort=True, bidir=True, start='#bg')\n", " \n", " def show(self, x, ax):\n", " bbox,label = x\n", " for b,l in zip(bbox, label): \n", " if l != '#bg': _draw_rect(ax, b, hw=False, rev=True, text=l)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class BBox(Item): tfm,tfm_ds,tfm_kwargs = BBoxEncoder,BBoxScaler,{'tfm_y': TfmY.Bbox}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def bb_pad_collate(samples, pad_idx=0):\n", " max_len = max([len(s[1][1]) for s in samples])\n", " bboxes = torch.zeros(len(samples), max_len, 4)\n", " labels = torch.zeros(len(samples), max_len).long() + pad_idx\n", " imgs = []\n", " for i,s in enumerate(samples):\n", " imgs.append(s[0][None])\n", " bbs, lbls = s[1]\n", " if not (bbs.nelement() == 0):\n", " bboxes[i,-len(lbls):] = bbs\n", " labels[i,-len(lbls):] = tensor(lbls)\n", " return torch.cat(imgs,0), (bboxes,labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CocoData(DataBlock):\n", " types = Image,BBox\n", " def __init__(self, source, *args, **kwargs):\n", " super().__init__(source, *args, **kwargs)\n", " images, lbl_bbox = get_annotations(source/'train.json')\n", " self.img2bbox = dict(zip(images, lbl_bbox))\n", " \n", " get_items = image_getter('train')\n", " split = random_splitter()\n", " label_func = lambda o,self: self.img2bbox[o.name]\n", " \n", " def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):\n", " return super().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=bs, tfm_kwargs=tfm_kwargs,\n", " collate_fn=bb_pad_collate, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.COCO_TINY)\n", "dblk = CocoData(source)\n", "data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use of the low-level API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also use `DataSource` directly with just one transform that does everything, without using the blocks. You will have to provide your show method if you want to use `show_batch` however (no need to decode if you do everything in one transform)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def size_f(x): return tensor(x.size).float()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.COCO_TINY)\n", "fns, lbl_bbox = get_annotations(path/'train.json')\n", "img2bbox = dict(zip(fns, lbl_bbox))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CocoTransform(Transform):\n", " def __init__(self): self.vocab = None\n", " def setup(self, data):\n", " if self.vocab is not None: return\n", " vals = set()\n", " for c in data.train: vals = vals.union(set(img2bbox[c.name][1]))\n", " self.vocab,self.otoi = uniqueify(list(vals), sort=True, bidir=True, start='#bg')\n", " \n", " def __call__(self, fn):\n", " img = PIL.Image.open(fn)\n", " bbox,lbl = img2bbox[fn.name]\n", " #flip and rescale to -1,1\n", " bbox = tensor(bbox).view(-1,2).flip(1) * 2/size_f(img) - 1\n", " lbl = [self.otoi[l] for l in lbl if l in self.otoi]\n", " return (img, [bbox.view(-1,4), lbl])\n", " \n", " def show(self, o, ax):\n", " img,(bbox,lbl) = o\n", " show_image(img, ax)\n", " lbl = [self.vocab[l] for l in lbl if l != 0] #Unpad and decode\n", " bbox = bbox[-len(lbl):,] #Unpad\n", " bbox = (bbox.view(-1,2) + 1) * tensor(img.shape[:2]).float() / 2\n", " bbox = bbox.flip(1).view(-1,4)\n", " for b,l in zip(bbox, lbl): _draw_rect(ax, b, hw=False, rev=True, text=l)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fnames = get_image_files(path)\n", "splits = random_splitter()(fnames)\n", "ds = DataSource(fnames, tfms=CocoTransform(), filters=splits, tfm_y=TfmY.Bbox)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds = ds.transformed(tfms=ds_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch(*get_dls(ds, 16, collate_fn=bb_pad_collate, tfms=dl_tfms))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Coco data block lower level" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CocoData2(DataBlock):\n", " types = Image,Item,MultiCategory\n", " def __init__(self, source, *args, **kwargs):\n", " super().__init__(source, *args, **kwargs)\n", " images, lbl_bbox = get_annotations(source/'train.json')\n", " self.img2bbox = dict(zip(images, lbl_bbox))\n", " \n", " get_items = image_getter('train')\n", " split = random_splitter()\n", " label_funcs = (noop, lambda o,self: self.img2bbox[o.name][0],\n", " lambda o,self: self.img2bbox[o.name][1])\n", " \n", " def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):\n", " return super().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=bs, tfm_kwargs=tfm_kwargs, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "source = untar_data(URLs.COCO_TINY)\n", "dblk = CocoData2(source)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dsrc = dblk.datasource(tfms=ds_tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.one_batch(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! python notebook2script.py \"200_datablock_config.ipynb\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }