[ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp data.block" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.torch_basics import *\n", "from fastai.data.core import *\n", "from fastai.data.load import *\n", "from fastai.data.external import *\n", "from fastai.data.transforms import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data block\n", "\n", "> High level API to quickly get your data in a `DataLoaders`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TransformBlock -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TransformBlock():\n", " \"A basic wrapper that links defaults transforms for the data block API\"\n", " def __init__(self, type_tfms=None, item_tfms=None, batch_tfms=None, dl_type=None, dls_kwargs=None):\n", " self.type_tfms = L(type_tfms)\n", " self.item_tfms = ToTensor + L(item_tfms)\n", " self.batch_tfms = L(batch_tfms)\n", " self.dl_type,self.dls_kwargs = dl_type,({} if dls_kwargs is None else dls_kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def CategoryBlock(vocab=None, sort=True, add_na=False):\n", " \"`TransformBlock` for single-label categorical targets\"\n", " return TransformBlock(type_tfms=Categorize(vocab=vocab, sort=sort, add_na=add_na))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def MultiCategoryBlock(encoded=False, vocab=None, add_na=False):\n", " \"`TransformBlock` for multi-label categorical targets\"\n", " tfm = EncodedMultiCategorize(vocab=vocab) if 
encoded else [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]\n", " return TransformBlock(type_tfms=tfm)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RegressionBlock(n_out=None):\n", " \"`TransformBlock` for float targets\"\n", " return TransformBlock(type_tfms=RegressionSetup(c=n_out))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## General API" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from inspect import isfunction,ismethod" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _merge_grouper(o):\n", " if isinstance(o, LambdaType): return id(o)\n", " elif isinstance(o, type): return o\n", " elif (isfunction(o) or ismethod(o)): return o.__qualname__\n", " return o.__class__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _merge_tfms(*tfms):\n", " \"Group the `tfms` in a single list, removing duplicates (from the same class) and instantiating\"\n", " g = groupby(concat(*tfms), _merge_grouper)\n", " return L(v[-1] for k,v in g.items()).map(instantiate)\n", "\n", "def _zip(x): return L(x).zip()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For example, so not exported\n", "from fastai.vision.core import *\n", "from fastai.vision.data import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "tfms = _merge_tfms([Categorize, MultiCategorize, Categorize(['dog', 'cat'])], Categorize(['a', 'b']))\n", "#If there are several instantiated versions, the last one is kept.\n", "test_eq(len(tfms), 2)\n", "test_eq(tfms[1].__class__, MultiCategorize)\n", "test_eq(tfms[0].__class__, Categorize)\n", "test_eq(tfms[0].vocab, ['a', 'b'])\n", "\n", "tfms = _merge_tfms([PILImage.create, 
#export
@docs
@funcs_kwargs
class DataBlock():
    "Generic container to quickly build `Datasets` and `DataLoaders`"
    get_x=get_items=splitter=get_y = None
    blocks,dl_type = (TransformBlock,TransformBlock),TfmdDL
    _methods = 'get_items splitter get_y get_x'.split()
    _msg = "If you wanted to compose several transforms in your getter don't forget to wrap them in a `Pipeline`."
    def __init__(self, blocks=None, dl_type=None, getters=None, n_inp=None, item_tfms=None, batch_tfms=None, **kwargs):
        # Blocks may be given as callables (classes/factories) or as ready-made instances
        blocks = L(self.blocks if blocks is None else blocks)
        blocks = L(b() if callable(b) else b for b in blocks)
        self.type_tfms = blocks.attrgot('type_tfms', L())
        self.default_item_tfms = _merge_tfms(*blocks.attrgot('item_tfms', L()))
        self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))
        # The last block declaring a `dl_type` wins, unless one is passed explicitly
        for b in blocks:
            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type
        if dl_type is not None: self.dl_type = dl_type
        # Re-wrap `dataloaders` with `delegates` on the chosen loader's `__init__`
        # (presumably so its signature exposes that loader's keyword arguments)
        self.dataloaders = delegates(self.dl_type.__init__)(self.dataloaders)
        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))

        # By default every block but the last is an input (with at least one input)
        self.n_inp = ifnone(n_inp, max(1, len(blocks)-1))
        # Always work on a private list: the slice assignments below must neither mutate the
        # caller's `getters` nor fail when `getters` is an immutable sequence such as a tuple
        self.getters = [noop]*len(self.type_tfms) if getters is None else list(getters)
        if self.get_x:
            if len(L(self.get_x)) != self.n_inp:
                raise ValueError(f'get_x contains {len(L(self.get_x))} functions, but must contain {self.n_inp} (one for each input)\n{self._msg}')
            self.getters[:self.n_inp] = L(self.get_x)
        if self.get_y:
            n_targs = len(self.getters) - self.n_inp
            if len(L(self.get_y)) != n_targs:
                raise ValueError(f'get_y contains {len(L(self.get_y))} functions, but must contain {n_targs} (one for each target)\n{self._msg}')
            self.getters[self.n_inp:] = L(self.get_y)

        if kwargs: raise TypeError(f'invalid keyword arguments: {", ".join(kwargs.keys())}')
        self.new(item_tfms, batch_tfms)

    # One list of transforms per block: the getter (unrolled if it is a `Pipeline`) followed by the type transforms
    def _combine_type_tfms(self): return L([self.getters, self.type_tfms]).map_zip(
        lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)

    def new(self, item_tfms=None, batch_tfms=None):
        # Mutates this instance and returns it; no copy is made
        self.item_tfms = _merge_tfms(self.default_item_tfms, item_tfms)
        self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)
        return self

    @classmethod
    def from_columns(cls, blocks=None, getters=None, get_items=None, **kwargs):
        "Create a `DataBlock` whose `get_items` is composed with `_zip` and whose default getters are `ItemGetter(i)`, one per block"
        if getters is None: getters = L(ItemGetter(i) for i in range(2 if blocks is None else len(L(blocks))))
        get_items = _zip if get_items is None else compose(get_items, _zip)
        return cls(blocks=blocks, getters=getters, get_items=get_items, **kwargs)

    def datasets(self, source, verbose=False):
        self.source = source
        pv(f"Collecting items from {source}", verbose)
        # Without `get_items` the source itself is used as the items
        items = (self.get_items or noop)(source)
        pv(f"Found {len(items)} items", verbose)
        # Without `splitter` a `RandomSplitter()` is used
        splits = (self.splitter or RandomSplitter())(items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        return Datasets(items, tfms=self._combine_type_tfms(), splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)

    def dataloaders(self, source, path='.', verbose=False, **kwargs):
        dsets = self.datasets(source, verbose=verbose)
        # Explicit keyword arguments take precedence over the blocks' `dls_kwargs`
        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)

    _docs = dict(new="Set `item_tfms` and `batch_tfms` to the defaults merged with the ones passed, then return this `DataBlock`",
                 datasets="Create a `Datasets` object from `source`",
                 dataloaders="Create a `DataLoaders` object from `source`")


\n", "\n", "> DataBlock.datasets(**`source`**, **`verbose`**=*`False`*)\n", "\n", "Create a [`Datasets`](/data.core#Datasets) object from `source`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataBlock.datasets)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "


\n", "\n", "> DataBlock.dataloaders(**`source`**, **`path`**=*`'.'`*, **`verbose`**=*`False`*, **`bs`**=*`64`*, **`shuffle`**=*`False`*, **`num_workers`**=*`None`*, **`do_setup`**=*`True`*, **`pin_memory`**=*`False`*, **`timeout`**=*`0`*, **`batch_size`**=*`None`*, **`drop_last`**=*`False`*, **`indexed`**=*`None`*, **`n`**=*`None`*, **`device`**=*`None`*, **`wif`**=*`None`*, **`before_iter`**=*`None`*, **`after_item`**=*`None`*, **`before_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_iter`**=*`None`*, **`create_batches`**=*`None`*, **`create_item`**=*`None`*, **`create_batch`**=*`None`*, **`retain`**=*`None`*, **`get_idxs`**=*`None`*, **`sample`**=*`None`*, **`shuffle_fn`**=*`None`*, **`do_batch`**=*`None`*)\n", "\n", "Create a [`DataLoaders`](/data.core#DataLoaders) object from `source`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide_input\n", "dblock = DataBlock()\n", "show_doc(dblock.dataloaders, name=\"DataBlock.dataloaders\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can create a `DataBlock` by passing functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mnist = DataBlock(blocks = (ImageBlock(cls=PILImageBW),CategoryBlock),\n", " get_items = get_image_files,\n", " splitter = GrandparentSplitter(),\n", " get_y = parent_label)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each type comes with default transforms that will be applied\n", "- at the base level to create items in a tuple (usually input,target) from the base elements (like filenames)\n", "- at the item level of the datasets\n", "- at the batch level\n", "\n", "They are called respectively type transforms, item transforms, batch transforms. 
In the case of MNIST, the type transforms are the method to create a `PILImageBW` (for the input) and the `Categorize` transform (for the target), the item transform is `ToTensor` and the batch transforms are `Cuda` and `IntToFloatTensor`. You can add any other transforms by passing them in `DataBlock.datasets` or `DataBlock.dataloaders`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(mnist.type_tfms[0], [PILImageBW.create])\n", "test_eq(mnist.type_tfms[1].map(type), [Categorize])\n", "test_eq(mnist.default_item_tfms.map(type), [ToTensor])\n", "test_eq(mnist.default_batch_tfms.map(type), [IntToFloatTensor])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHsAAACLCAYAAABBVeZmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAE3klEQVR4nO2dTyh0axzHf4/srpAkpRQJCTvkioQsbBSykO4srGSDsvRnIVvq9a+s2Chyy07dEq5YKSlFKVnpXlLcu7DAuav3NL/nvWbMOzPnnJnv91PqfJs508On3/l5znPOGeM4jhAMMvweAPEOygaCsoGgbCAoGwjKBoKygYCSbYz51/p5N8Z883tcXpHp9wC8xHGcrO/bxphfROQvEdn2b0TeAlXZFn0i8reI/On3QLwCWXZIRDYcoPPFBuh3dTHGFIvIrYiUOY5z6/d4vAK1sn8TkWMk0SLYstf9HoTXwB3GjTG/isgfIlLoOM4/fo/HSxArOyQiv6OJFgGsbGQQKxsWygaCsoGgbCAoG4hoq178Vz31MJ+9wMoGgrKBoGwgKBsIygaCsoGgbCAoGwjKBoKygaBsICgbCMoGgrKBoGwgKBsIygaCsoGgbCAoGwjKBoKygaBsICgbCMoGgrKBSNmH3r29val8fn6ucn19vcoFBQXu9sTEhHrt+vpa5bW1NZXb2tpU7urqim2wFrW1tSq3t7e72xkZyas/VjYQlA0EZQMR7QE6gbll9/n5WeWRkRGVNzc3vRxOQnl4eHC38/Ly4v043rJLKBsKygYisPPs19dXlZubm1W+vLyM6fNmZmbc7c7OzojvXVxcVPn09FTl29v4nm+bn5+vcmamNxpY2UBQNhCUDURgerbdo4eHh1WO1qON0dPLoaEhlcfGxtztrKwsiURDQ4PKLy8vKu/v76vc29sb8fNsenp6VM7Ozo5p/5+FlQ0EZQNB2UAE5tz43d2dyqWlpTHt39TUpPLR0VHcY/rOwcGByoODgyrf399H3L+oqEjlw8NDlUtKSn5+cD/Cc+OEsqGgbCACM8/Ozc1VORQKqby+rr+GK3zeLCIyOzubnIGJyO7ursrRerTN1taWygnu0V+GlQ0EZQNB2UAEZp5t8/HxobJ9nbh9fXU8a8L232B6elrlubm5iO+3mZycVHlqakrlZF4bLpxnExHKhiKwh/Fk8v7+rrJ9a
rWjoyPi/vZyaktLi8r2EqjH8DBOKBsKygYCsmff3NyoXFFREdP+ZWVlKtu3/PoMezahbCgoG4jALHEmm8fHR3d7ZWUlpn2Li4tVPj4+TsiYvIaVDQRlA0HZQKTtPNteEh0YGHC3d3Z2Iu5rXzZk9+jCwsI4R5dUOM8mlA0FZQORNvNs+zKm8fFxlaP16XD29vZUDniP/jKsbCAoGwjKBiJtevb29rbKS0tLn77XPte9sLCgcqy3C6cKrGwgKBsIygYiZXu23aPDz31Ho66uTuXu7u6EjCnosLKBoGwgKBuIlFnPvrq6UrmxsVFl+5GTNqOjo+62fUut/YiPFIfr2YSyoaBsIALbs+1HTre2tqr89PQUcf+amhqVwx9JmWY92oY9m1A2FJQNRGB6dvhXFIqIVFdXqxx+r9b/UVlZqfLJyYnKOTk5cYwupWDPJpQNRWCWOJeXl1WOdtiuqqpSeXV1VWWgw/aXYWUDQdlAUDYQvvXss7Mzlefn52Pav6+vT2X723/Ij7CygaBsICgbCN96dnl5ucr2oy0uLi5U3tjYULm/vz85A0tjWNlAUDYQlA1EYJY4ScLgEiehbCgoG4ho8+xPj/8k9WBlA0HZQFA2EJQNBGUDQdlA/AdDRSN/yb0OvgAAAABJRU5ErkJggg==\n", "text/plain": [ "
# Build the datasets of the single-input MNIST block and check the vocab and the first item
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(dsets.vocab, ['3', '7'])
x,y = dsets.train[0]
test_eq(x.size,(28,28))
show_at(dsets.train, 0, cmap='Greys', figsize=(2,2));

# Unknown keyword arguments must be rejected
test_fail(lambda: DataBlock(wrong_kwarg=42, wrong_kwarg2='foo'))

# We can pass any number of blocks to `DataBlock`; `n_inp` decides how many of the first blocks
# are inputs, the others being targets. With three blocks the default is two inputs.
mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                  get_y=parent_label)
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(mnist.n_inp, 2)
test_eq(len(dsets.train[0]), 3)

# Two `get_y` functions for a single target must raise. The expected text goes in `contains`
# (per fastcore's `test_fail`, `msg` only labels the failure and is never compared to the error),
# and it matches the message `DataBlock.__init__` builds: "(one for each target)".
test_fail(lambda: DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                  get_y=[parent_label, noop],
                  n_inp=2), contains='get_y contains 2 functions, but must contain 1 (one for each target)')

# With `n_inp=1` the two remaining blocks are targets, so two `get_y` functions are required
mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                  n_inp=1,
                  get_y=[noop, Pipeline([noop, parent_label])])
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(len(dsets.train[0]), 3)
"metadata": {}, "outputs": [], "source": [ "#export\n", "def _short_repr(x):\n", " if isinstance(x, tuple): return f'({\", \".join([_short_repr(y) for y in x])})'\n", " if isinstance(x, list): return f'[{\", \".join([_short_repr(y) for y in x])}]'\n", " if not isinstance(x, Tensor): return str(x)\n", " if x.numel() <= 20 and x.ndim <=1: return str(x)\n", " return f'{x.__class__.__name__} of size {\"x\".join([str(d) for d in x.shape])}'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(_short_repr(TensorImage(torch.randn(40,56))), 'TensorImage of size 40x56')\n", "test_eq(_short_repr(TensorCategory([1,2,3])), 'TensorCategory([1, 2, 3])')\n", "test_eq(_short_repr((TensorImage(torch.randn(40,56)), TensorImage(torch.randn(32,20)))),\n", " '(TensorImage of size 40x56, TensorImage of size 32x20)')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _apply_pipeline(p, x):\n", " print(f\" {p}\\n starting from\\n {_short_repr(x)}\")\n", " for f in p.fs:\n", " name = f.name\n", " try:\n", " x = f(x)\n", " if name != \"noop\": print(f\" applying {name} gives\\n {_short_repr(x)}\")\n", " except Exception as e:\n", " print(f\" applying {name} failed.\")\n", " raise e\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.data.load import _collate_types\n", "\n", "def _find_fail_collate(s):\n", " s = L(*s)\n", " for x in s[0]:\n", " if not isinstance(x, _collate_types): return f\"{type(x).__name__} is not collatable\"\n", " for i in range_of(s[0]):\n", " try: _ = default_collate(s.itemgot(i))\n", " except:\n", " shapes = [getattr(o[i], 'shape', None) for o in s]\n", " return f\"Could not collate the {i}-th members of your tuples because got the following shapes\\n{','.join([str(s) for s in shapes])}\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
#export
def _has_tfms(pipeline):
    "Whether `pipeline.fs` holds at least one transform whose name is not 'noop'"
    return any(f.name != 'noop' for f in pipeline.fs)

@patch
def summary(self: DataBlock, source, bs=4, show_batch=False, **kwargs):
    "Steps through the transform pipeline for one batch, and optionally calls `show_batch(**kwargs)` on the transient `Dataloaders`."
    print("Setting-up type transforms pipelines")
    dsets = self.datasets(source, verbose=True)
    print("\nBuilding one sample")
    # Replay every type-transform pipeline on the first training item
    for tl in dsets.train.tls:
        _apply_pipeline(tl.tfms, get_first(dsets.train.items))
    print(f"\nFinal sample: {dsets.train[0]}\n\n")

    dls = self.dataloaders(source, bs=bs, verbose=True)
    print("\nBuilding one batch")
    # Never index past the end of the training set when it holds fewer than `bs` samples
    n_samples = min(bs, len(dsets.train))
    if _has_tfms(dls.train.after_item):
        print("Applying item_tfms to the first sample:")
        s = [_apply_pipeline(dls.train.after_item, dsets.train[0])]
        print(f"\nAdding the next {n_samples-1} samples")
        s += [dls.train.after_item(dsets.train[i]) for i in range(1, n_samples)]
    else:
        print("No item_tfms to apply")
        s = [dls.train.after_item(dsets.train[i]) for i in range(n_samples)]

    if _has_tfms(dls.train.before_batch):
        print("\nApplying before_batch to the list of samples")
        s = _apply_pipeline(dls.train.before_batch, s)
    else: print("\nNo before_batch transform to apply")

    print("\nCollating items in a batch")
    try:
        b = dls.train.create_batch(s)
        b = retain_types(b, s[0] if is_listy(s) else s)
    except Exception:
        print("Error! It's not possible to collate your items in a batch")
        why = _find_fail_collate(s)
        print("Make sure all parts of your samples are tensors of the same size" if why is None else why)
        # Re-raise the original error once the diagnosis has been printed
        raise

    if _has_tfms(dls.train.after_batch):
        print("\nApplying batch_tfms to the batch built")
        # The batch is moved to the loaders' device before the batch transforms run
        b = to_device(b, dls.device)
        b = _apply_pipeline(dls.train.after_batch, b)
    else: print("\nNo batch_tfms to apply")

    if show_batch: dls.show_batch(**kwargs)


