{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp data.block" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.torch_basics import *\n", "from fastai.data.core import *\n", "from fastai.data.load import *\n", "from fastai.data.external import *\n", "from fastai.data.transforms import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data block\n", "\n", "> High level API to quickly get your data in a `DataLoaders`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TransformBlock -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TransformBlock():\n", " \"A basic wrapper that links defaults transforms for the data block API\"\n", " def __init__(self, type_tfms=None, item_tfms=None, batch_tfms=None, dl_type=None, dls_kwargs=None):\n", " self.type_tfms = L(type_tfms)\n", " self.item_tfms = ToTensor + L(item_tfms)\n", " self.batch_tfms = L(batch_tfms)\n", " self.dl_type,self.dls_kwargs = dl_type,({} if dls_kwargs is None else dls_kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def CategoryBlock(vocab=None, sort=True, add_na=False):\n", " \"`TransformBlock` for single-label categorical targets\"\n", " return TransformBlock(type_tfms=Categorize(vocab=vocab, sort=sort, add_na=add_na))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def MultiCategoryBlock(encoded=False, vocab=None, 
add_na=False):\n", " \"`TransformBlock` for multi-label categorical targets\"\n", " tfm = EncodedMultiCategorize(vocab=vocab) if encoded else [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]\n", " return TransformBlock(type_tfms=tfm)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RegressionBlock(n_out=None):\n", " \"`TransformBlock` for float targets\"\n", " return TransformBlock(type_tfms=RegressionSetup(c=n_out))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## General API" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from inspect import isfunction,ismethod" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _merge_grouper(o):\n", " if isinstance(o, LambdaType): return id(o)\n", " elif isinstance(o, type): return o\n", " elif (isfunction(o) or ismethod(o)): return o.__qualname__\n", " return o.__class__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _merge_tfms(*tfms):\n", " \"Group the `tfms` in a single list, removing duplicates (from the same class) and instantiating\"\n", " g = groupby(concat(*tfms), _merge_grouper)\n", " return L(v[-1] for k,v in g.items()).map(instantiate)\n", "\n", "def _zip(x): return L(x).zip()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For example, so not exported\n", "from fastai.vision.core import *\n", "from fastai.vision.data import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "tfms = _merge_tfms([Categorize, MultiCategorize, Categorize(['dog', 'cat'])], Categorize(['a', 'b']))\n", "#If there are several instantiated versions, the last one is kept.\n", "test_eq(len(tfms), 2)\n", "test_eq(tfms[1].__class__, MultiCategorize)\n", 
@docs
@funcs_kwargs
class DataBlock():
    "Generic container to quickly build `Datasets` and `DataLoaders`"
    # Optional hooks, all unset by default; their names are listed in `_methods`
    # (presumably so `funcs_kwargs` accepts them as keyword arguments — confirm against fastcore).
    get_x=get_items=splitter=get_y = None
    # Class-level defaults used when `blocks` / `dl_type` are not passed to `__init__`
    blocks,dl_type = (TransformBlock,TransformBlock),TfmdDL
    _methods = 'get_items splitter get_y get_x'.split()
    _msg = "If you wanted to compose several transforms in your getter don't forget to wrap them in a `Pipeline`."
    def __init__(self, blocks=None, dl_type=None, getters=None, n_inp=None, item_tfms=None, batch_tfms=None, **kwargs):
        # blocks:     block objects, or callables returning them (default: the class-level `blocks`)
        # dl_type:    explicit dataloader type; overrides any `dl_type` carried by a block
        # getters:    one getter per block (default: `noop` for each); copied, never mutated
        # n_inp:      number of leading blocks treated as inputs (default: all but the last, at least 1)
        # item_tfms / batch_tfms: extra transforms merged with the blocks' defaults (see `new`)
        # Raises `ValueError` when `get_x`/`get_y` do not match the number of inputs/targets,
        # and `TypeError` when unexpected keyword arguments remain in `kwargs`.
        blocks = L(self.blocks if blocks is None else blocks)
        # Entries given as classes or factory functions are called to obtain the actual block
        blocks = L(b() if callable(b) else b for b in blocks)
        self.type_tfms = blocks.attrgot('type_tfms', L())
        self.default_item_tfms  = _merge_tfms(*blocks.attrgot('item_tfms',  L()))
        self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))
        # The last block carrying a non-None `dl_type` wins...
        for b in blocks:
            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type
        # ...unless the caller passed one explicitly
        if dl_type is not None: self.dl_type = dl_type
        # Rebind `dataloaders` on the instance, delegating to the chosen loader's `__init__`
        self.dataloaders = delegates(self.dl_type.__init__)(self.dataloaders)
        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))

        self.n_inp = ifnone(n_inp, max(1, len(blocks)-1))
        # Work on a private copy: the slice assignments below would otherwise overwrite
        # entries of the sequence the caller passed in as `getters`.
        self.getters = [noop]*len(self.type_tfms) if getters is None else list(getters)
        if self.get_x:
            if len(L(self.get_x)) != self.n_inp:
                raise ValueError(f'get_x contains {len(L(self.get_x))} functions, but must contain {self.n_inp} (one for each input)\n{self._msg}')
            self.getters[:self.n_inp] = L(self.get_x)
        if self.get_y:
            n_targs = len(self.getters) - self.n_inp
            if len(L(self.get_y)) != n_targs:
                raise ValueError(f'get_y contains {len(L(self.get_y))} functions, but must contain {n_targs} (one for each target)\n{self._msg}')
            self.getters[self.n_inp:] = L(self.get_y)

        if kwargs: raise TypeError(f'invalid keyword arguments: {", ".join(kwargs.keys())}')
        self.new(item_tfms, batch_tfms)

    # For each (getter, type_tfms) pair, put the getter's functions (or the getter itself when
    # it is not a `Pipeline`) in front of that block's type transforms.
    def _combine_type_tfms(self): return L([self.getters, self.type_tfms]).map_zip(
        lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)

    def new(self, item_tfms=None, batch_tfms=None):
        # Merges the given transforms with the block defaults, stores them on this
        # instance and returns `self` (no separate object is constructed here).
        self.item_tfms  = _merge_tfms(self.default_item_tfms,  item_tfms)
        self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)
        return self

    @classmethod
    def from_columns(cls, blocks=None, getters=None, get_items=None, **kwargs):
        # Default getters are `ItemGetter(i)`, one per block (two when `blocks` is None)
        if getters is None: getters = L(ItemGetter(i) for i in range(2 if blocks is None else len(L(blocks))))
        # `_zip` is applied to the source, after `get_items` when one is given
        get_items = _zip if get_items is None else compose(get_items, _zip)
        return cls(blocks=blocks, getters=getters, get_items=get_items, **kwargs)

    def datasets(self, source, verbose=False):
        # Remember the source, collect items with `get_items` (or pass `source` through),
        # split with `splitter` (or a `RandomSplitter()`), then build the `Datasets`.
        self.source = source                     ; pv(f"Collecting items from {source}", verbose)
        items = (self.get_items or noop)(source) ; pv(f"Found {len(items)} items", verbose)
        splits = (self.splitter or RandomSplitter())(items)
        pv(f"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}", verbose)
        return Datasets(items, tfms=self._combine_type_tfms(), splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)

    def dataloaders(self, source, path='.', verbose=False, **kwargs):
        dsets = self.datasets(source, verbose=verbose)
        # Precedence: block-provided `dls_kwargs` < caller `kwargs` < the `verbose` argument
        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}
        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)

    _docs = dict(new="Create a new `DataBlock` with other `item_tfms` and `batch_tfms`",
                 datasets="Create a `Datasets` object from `source`",
                 dataloaders="Create a `DataLoaders` object from `source`")
DataBlock.datasets
[source]DataBlock.datasets
(**`source`**, **`verbose`**=*`False`*)\n",
"\n",
"Create a [`Datasets`](/data.core#Datasets) object from `source`"
],
"text/plain": [
"DataBlock.dataloaders
[source]DataBlock.dataloaders
(**`source`**, **`path`**=*`'.'`*, **`verbose`**=*`False`*, **`bs`**=*`64`*, **`shuffle`**=*`False`*, **`num_workers`**=*`None`*, **`do_setup`**=*`True`*, **`pin_memory`**=*`False`*, **`timeout`**=*`0`*, **`batch_size`**=*`None`*, **`drop_last`**=*`False`*, **`indexed`**=*`None`*, **`n`**=*`None`*, **`device`**=*`None`*, **`wif`**=*`None`*, **`before_iter`**=*`None`*, **`after_item`**=*`None`*, **`before_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_iter`**=*`None`*, **`create_batches`**=*`None`*, **`create_item`**=*`None`*, **`create_batch`**=*`None`*, **`retain`**=*`None`*, **`get_idxs`**=*`None`*, **`sample`**=*`None`*, **`shuffle_fn`**=*`None`*, **`do_batch`**=*`None`*)\n",
"\n",
"Create a [`DataLoaders`](/data.core#DataLoaders) object from `source`"
],
"text/plain": [
"