{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp data.core" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.torch_basics import *\n", "from fastai2.test import *\n", "from fastai2.data.load import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data core\n", "\n", "> Core functionality for gathering data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The classes here provide functionality for applying a list of transforms to a set of items (`TfmdList`, `DataSource`) or a `DataLoader` (`TfmdDl`) as well as the base class used to gather the data for model training: `DataBunch`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TfmdDL -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_batch(x, y, samples, ctxs=None, max_n=9, **kwargs):\n", " if ctxs is None: ctxs = Inf.nones\n", " for i in range_of(samples[0]):\n", " ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n", " return ctxs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`show_batch` is a type-dispatched function that is responsible for showing decoded `samples`. `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. There is a different implementation of `show_batch` if `x` is a `TensorImage` or a `TensorText` for instance (see vision.core or text.data for more details). `ctxs` can be passed but the function is responsible to create them if necessary. `kwargs` depend on the specific implementation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_results(x, y, samples, outs, ctxs=None, max_n=9, **kwargs):\n", " if ctxs is None: ctxs = Inf.nones\n", " for i in range(len(samples[0])):\n", " ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n", " for i in range(len(outs[0])):\n", " ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(i),ctxs,range(max_n))]\n", " return ctxs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`show_results` is a type-dispatched function that is responsible for showing decoded `samples` and their corresponding `outs`. Like in `show_batch`, `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. `ctxs` can be passed but the function is responsible to create them if necessary. `kwargs` depend on the specific implementation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = [\"show_batch\", \"show_results\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_batch_tfms = ('after_item','before_batch','after_batch')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates()\n", "class TfmdDL(DataLoader):\n", " \"Transformed `DataLoader`\"\n", " def __init__(self, dataset, bs=16, shuffle=False, num_workers=None, **kwargs):\n", " if num_workers is None: num_workers = min(16, defaults.cpus)\n", " for nm in _batch_tfms:\n", " kwargs[nm] = Pipeline(kwargs.get(nm,None), as_item=(nm=='before_batch'))\n", " super().__init__(dataset, bs=bs, shuffle=shuffle, num_workers=num_workers, **kwargs)\n", " for nm in _batch_tfms: kwargs[nm].setup(self)\n", "\n", " def _one_pass(self):\n", " its = self.after_batch(self.do_batch([self.do_item(0)]))\n", " self._device = find_device(its)\n", " self._n_inp = 1 if not isinstance(its, (list,tuple)) or len(its)==1 else len(its)-1\n", " self._retain_dl = partial(retain_types, typs=mapped(type,its))\n", "\n", " def _retain_dl(self,b):\n", " self._one_pass()\n", " # we just replaced ourselves, so this is *not* recursive! :)\n", " return self._retain_dl(b)\n", "\n", " def before_iter(self):\n", " super().before_iter()\n", " split_idx = getattr(self.dataset, 'split_idx', None)\n", " for nm in _batch_tfms:\n", " f = getattr(self,nm)\n", " if isinstance(f,Pipeline): f.split_idx=split_idx\n", "\n", " def decode(self, b): return self.before_batch.decode(self.after_batch.decode(self._retain_dl(b)))\n", " def decode_batch(self, b, max_n=9, full=True): return self._decode_batch(self.decode(b), max_n, full)\n", "\n", " def _decode_batch(self, b, max_n=9, full=True):\n", " f = self.after_item.decode\n", " f = compose(f, partial(getattr(self.dataset,'decode',noop), full = full))\n", " return L(batch_to_samples(b, max_n=max_n)).map(f)\n", "\n", " def _pre_show_batch(self, b, max_n=9):\n", " \"Decode `b` to be ready for `show_batch`\"\n", " b = self.decode(b)\n", " if hasattr(b, 'show'): return b,None,None\n", " its = self._decode_batch(b, max_n, full=False)\n", " if not is_listy(b): b,its = [b],L((o,) for o in its)\n", " return detuplify(b[:self.n_inp]),detuplify(b[self.n_inp:]),its\n", "\n", " def show_batch(self, b=None, max_n=9, ctxs=None, show=True, **kwargs):\n", " if b is None: b = self.one_batch()\n", " if not show: return self._pre_show_batch(b, max_n=max_n)\n", " show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)\n", "\n", " def show_results(self, b, out, max_n=9, ctxs=None, show=True, **kwargs):\n", " x,y,its = self.show_batch(b, max_n=max_n, show=False)\n", " b_out = b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,))\n", " x1,y1,outs = self.show_batch(b_out, max_n=max_n, show=False)\n", " res = (x,x1,None,None) if its is None else (x, y, its, outs.itemgot(slice(self.n_inp,None)))\n", " if not show: return res\n", " show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)\n", " \n", " @property\n", " def device(self):\n", " if not hasattr(self, '_device'): _ = self._one_pass()\n", " return self._device\n", "\n", " @property\n", " def n_inp(self):\n", " if hasattr(self.dataset, 'n_inp'): return self.dataset.n_inp\n", " if not hasattr(self, '_n_inp'): self._one_pass()\n", " return self._n_inp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `TfmdDL` is a 
`DataLoader` that creates `Pipeline` from a list of `Transform`s for the callbacks `after_item`, `before_batch` and `after_batch`. As a result, it can decode or show a processed `batch`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "add_docs(TfmdDL,\n", " decode=\"Decode `b` using `tfms`\",\n", " decode_batch=\"Decode `b` entirely\",\n", " show_batch=\"Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a `DataLoader`)\",\n", " show_results=\"Show each item of `b` and `out`\",\n", " before_iter=\"override\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _Category(int, ShowTitle): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test retain type\n", "class NegTfm(Transform):\n", " def encodes(self, x): return torch.neg(x)\n", " def decodes(self, x): return torch.neg(x)\n", " \n", "tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)\n", "b = tdl.one_batch()\n", "test_eq(type(b[0]), TensorImage)\n", "b = (tensor([1.,1.,1.,1.]),)\n", "test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): \n", " def encodes(self, x): return x \n", " def decodes(self, x): return Int(x) \n", "\n", "@Transform\n", "def f(x)->None: return Tuple((x,x))\n", "\n", "start = torch.arange(50)\n", "test_eq_type(f(2), Tuple((2,2)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = A()\n", "tdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)\n", "x,y = tdl.one_batch()\n", "test_eq(type(y), Tuple)\n", "\n", "s = tdl.decode_batch((x,y))\n", "test_eq(type(s[0][1]), Tuple)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)\n", "test_eq(tdl.dataset[0], start[0])\n", "test_eq(len(tdl), (50-1)//4+1)\n", "test_eq(tdl.bs, 4)\n", "test_stdout(tdl.show_batch, '0\\n1\\n2\\n3')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "DataLoader.one_batch [source]\n", "\n", "DataLoader.one_batch()" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.one_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "TfmdDL.decode [source]\n", "\n", "TfmdDL.decode(**`b`**)\n", "\n", "Decode `b` using `tfms`" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.decode)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "TfmdDL.decode_batch [source]\n", "\n", "TfmdDL.decode_batch(**`b`**, **`max_n`**=*`9`*, **`full`**=*`True`*)\n", "\n", "Decode `b` entirely" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.decode_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "TfmdDL.show_batch [source]\n", "\n", "TfmdDL.show_batch(**`b`**=*`None`*, **`max_n`**=*`9`*, **`ctxs`**=*`None`*, **`show`**=*`True`*, **\\*\\*`kwargs`**)\n", "\n", "Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a [`DataLoader`](/data.load.html#DataLoader))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.show_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "DataBunch.__getitem__ [source]\n", "\n", "DataBunch.__getitem__(**`i`**)\n", "\n", "Retrieve [`DataLoader`](/data.load.html#DataLoader) at `i` (`0` is training, `1` is validation)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataBunch.__getitem__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "TfmdList.subset [source]\n", "\n", "TfmdList.subset(**`i`**)\n", "\n", "New [`TfmdList`](/data.core.html#TfmdList) with same tfms that only includes items in `i`th split" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdList.subset)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "DataSource.decode [source]\n", "\n", "DataSource.decode(**`o`**, **`full`**=*`True`*)\n", "\n", "Compose `decode` of all `tuple_tfms` then all `tfms` on `i`" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataSource.decode)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "DataSource.show [source]\n", "\n", "DataSource.show(**`o`**, **`ctx`**=*`None`*, **\\*\\*`kwargs`**)\n", "\n", "Show item `o` in `ctx`" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataSource.show)" ] }