{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp data.core" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.torch_basics import *\n", "from fastai.data.load import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data core\n", "\n", "> Core functionality for gathering data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The classes here provide functionality for applying a list of transforms to a set of items (`TfmdLists`, `Datasets`) or a `DataLoader` (`TfmdDL`), as well as the base class used to gather the data for model training: `DataLoaders`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TfmdDL -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_batch(x, y, samples, ctxs=None, max_n=9, **kwargs):\n", " if ctxs is None: ctxs = Inf.nones\n", " if hasattr(samples[0], 'show'):\n", " ctxs = [s.show(ctx=c, **kwargs) for s,c,_ in zip(samples,ctxs,range(max_n))]\n", " else:\n", " for i in range_of(samples[0]):\n", " ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n", " return ctxs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`show_batch` is a type-dispatched function that is responsible for showing decoded `samples`. `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. 
There is a different implementation of `show_batch` if `x` is a `TensorImage` or a `TensorText`, for instance (see vision.core or text.data for more details). `ctxs` can be passed but the function is responsible for creating them if necessary. `kwargs` depend on the specific implementation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_results(x, y, samples, outs, ctxs=None, max_n=9, **kwargs):\n", " if ctxs is None: ctxs = Inf.nones\n", " for i in range(len(samples[0])):\n", " ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n", " for i in range(len(outs[0])):\n", " ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(i),ctxs,range(max_n))]\n", " return ctxs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`show_results` is a type-dispatched function that is responsible for showing decoded `samples` and their corresponding `outs`. Like in `show_batch`, `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. `ctxs` can be passed but the function is responsible for creating them if necessary. `kwargs` depend on the specific implementation."
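] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As an illustration (a sketch added here, not part of the exported library): `@typedispatch` registers every definition sharing a name in one dispatch table and selects the implementation matching the argument types, which is how the type-specific versions of `show_batch` and `show_results` get picked. The name `_show_demo` below is hypothetical:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "# Sketch: two overloads registered under the same name; dispatch follows the type of `x`\n", "@typedispatch\n", "def _show_demo(x:TensorImage, y): return 'image'\n", "@typedispatch\n", "def _show_demo(x, y): return 'default'\n", "test_eq(_show_demo(TensorImage([1]), None), 'image')\n", "test_eq(_show_demo(tensor([1]), None), 'default')"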
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = [\"show_batch\", \"show_results\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_batch_tfms = ('after_item','before_batch','after_batch')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates()\n", "class TfmdDL(DataLoader):\n", " \"Transformed `DataLoader`\"\n", " def __init__(self, dataset, bs=64, shuffle=False, num_workers=None, verbose=False, do_setup=True, **kwargs):\n", " if num_workers is None: num_workers = min(16, defaults.cpus)\n", " for nm in _batch_tfms: kwargs[nm] = Pipeline(kwargs.get(nm,None))\n", " super().__init__(dataset, bs=bs, shuffle=shuffle, num_workers=num_workers, **kwargs)\n", " if do_setup:\n", " for nm in _batch_tfms:\n", " pv(f\"Setting up {nm}: {kwargs[nm]}\", verbose)\n", " kwargs[nm].setup(self)\n", "\n", " def _one_pass(self):\n", " b = self.do_batch([self.do_item(None)])\n", " if self.device is not None: b = to_device(b, self.device)\n", " its = self.after_batch(b)\n", " self._n_inp = 1 if not isinstance(its, (list,tuple)) or len(its)==1 else len(its)-1\n", " self._types = explode_types(its)\n", "\n", " def _retain_dl(self,b):\n", " if not getattr(self, '_types', None): self._one_pass()\n", " return retain_types(b, typs=self._types)\n", "\n", " @delegates(DataLoader.new)\n", " def new(self, dataset=None, cls=None, **kwargs):\n", " res = super().new(dataset, cls, do_setup=False, **kwargs)\n", " if not hasattr(self, '_n_inp') or not hasattr(self, '_types'):\n", " try:\n", " self._one_pass()\n", " res._n_inp,res._types = self._n_inp,self._types\n", " except: print(\"Could not do one pass in your dataloader, there is something wrong in it\")\n", " else: res._n_inp,res._types = self._n_inp,self._types\n", " return res\n", "\n", " def before_iter(self):\n", " super().before_iter()\n", " 
split_idx = getattr(self.dataset, 'split_idx', None)\n", " for nm in _batch_tfms:\n", " f = getattr(self,nm)\n", " if isinstance(f,Pipeline): f.split_idx=split_idx\n", "\n", " def decode(self, b): return to_cpu(self.after_batch.decode(self._retain_dl(b)))\n", " def decode_batch(self, b, max_n=9, full=True): return self._decode_batch(self.decode(b), max_n, full)\n", "\n", " def _decode_batch(self, b, max_n=9, full=True):\n", " f = self.after_item.decode\n", " f1 = self.before_batch.decode\n", " f = compose(f1, f, partial(getattr(self.dataset,'decode',noop), full = full))\n", " return L(batch_to_samples(b, max_n=max_n)).map(f)\n", "\n", " def _pre_show_batch(self, b, max_n=9):\n", " \"Decode `b` to be ready for `show_batch`\"\n", " b = self.decode(b)\n", " if hasattr(b, 'show'): return b,None,None\n", " its = self._decode_batch(b, max_n, full=False)\n", " if not is_listy(b): b,its = [b],L((o,) for o in its)\n", " return detuplify(b[:self.n_inp]),detuplify(b[self.n_inp:]),its\n", "\n", " def show_batch(self, b=None, max_n=9, ctxs=None, show=True, unique=False, **kwargs):\n", " if unique:\n", " old_get_idxs = self.get_idxs\n", " self.get_idxs = lambda: Inf.zeros\n", " if b is None: b = self.one_batch()\n", " if not show: return self._pre_show_batch(b, max_n=max_n)\n", " show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)\n", " if unique: self.get_idxs = old_get_idxs\n", "\n", " def show_results(self, b, out, max_n=9, ctxs=None, show=True, **kwargs):\n", " x,y,its = self.show_batch(b, max_n=max_n, show=False)\n", " b_out = type(b)(b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,)))\n", " x1,y1,outs = self.show_batch(b_out, max_n=max_n, show=False)\n", " res = (x,x1,None,None) if its is None else (x, y, its, outs.itemgot(slice(self.n_inp,None)))\n", " if not show: return res\n", " show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)\n", "\n", " @property\n", " def n_inp(self):\n", " if hasattr(self.dataset, 'n_inp'): return 
self.dataset.n_inp\n", " if not hasattr(self, '_n_inp'): self._one_pass()\n", " return self._n_inp\n", "\n", " def to(self, device):\n", " self.device = device\n", " for tfm in self.after_batch.fs:\n", " for a in L(getattr(tfm, 'parameters', None)): setattr(tfm, a, getattr(tfm, a).to(device))\n", " return self" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `TfmdDL` is a `DataLoader` that creates `Pipeline` from a list of `Transform`s for the callbacks `after_item`, `before_batch` and `after_batch`. As a result, it can decode or show a processed `batch`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "add_docs(TfmdDL,\n", " decode=\"Decode `b` using `tfms`\",\n", " decode_batch=\"Decode `b` entirely\",\n", " new=\"Create a new version of self with a few changed attributes\",\n", " show_batch=\"Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a `DataLoader`)\",\n", " show_results=\"Show each item of `b` and `out`\",\n", " before_iter=\"override\",\n", " to=\"Put self and its transforms state on `device`\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _Category(int, ShowTitle): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test retain type\n", "class NegTfm(Transform):\n", " def encodes(self, x): return torch.neg(x)\n", " def decodes(self, x): return torch.neg(x)\n", " \n", "tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)\n", "b = tdl.one_batch()\n", "test_eq(type(b[0]), TensorImage)\n", "b = (tensor([1.,1.,1.,1.]),)\n", "test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(Transform): \n", " def encodes(self, x): return x \n", " def decodes(self, x): return TitledInt(x) \n", "\n", "@Transform\n", 
"def f(x)->None: return fastuple((x,x))\n", "\n", "start = torch.arange(50)\n", "test_eq_type(f(2), fastuple((2,2)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = A()\n", "tdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)\n", "x,y = tdl.one_batch()\n", "test_eq(type(y), fastuple)\n", "\n", "s = tdl.decode_batch((x,y))\n", "test_eq(type(s[0][1]), fastuple)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)\n", "test_eq(tdl.dataset[0], start[0])\n", "test_eq(len(tdl), (50-1)//4+1)\n", "test_eq(tdl.bs, 4)\n", "test_stdout(tdl.show_batch, '0\\n1\\n2\\n3')\n", "test_stdout(partial(tdl.show_batch, unique=True), '0\\n0\\n0\\n0')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class B(Transform):\n", " parameters = 'a'\n", " def __init__(self): self.a = torch.tensor(0.)\n", " def encodes(self, x): x\n", " \n", "tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=B(), bs=4)\n", "test_eq(tdl.after_batch.fs[0].a.device, torch.device('cpu'))\n", "tdl.to(default_device())\n", "test_eq(tdl.after_batch.fs[0].a.device, default_device())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"DataLoader.one_batch\" class=\"doc_header\"><code>DataLoader.one_batch</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>DataLoader.one_batch</code>()\n", "\n", "Return one batch from [`DataLoader`](/data.load.html#DataLoader)." ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataLoader.one_batch)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"TfmdDL.decode\" class=\"doc_header\"><code>TfmdDL.decode</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>TfmdDL.decode</code>(**`b`**)\n", "\n", "Decode `b` using `tfms`" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.decode)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"TfmdDL.decode_batch\" class=\"doc_header\"><code>TfmdDL.decode_batch</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>TfmdDL.decode_batch</code>(**`b`**, **`max_n`**=*`9`*, **`full`**=*`True`*)\n", "\n", "Decode `b` entirely" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.decode_batch)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"TfmdDL.show_batch\" class=\"doc_header\"><code>TfmdDL.show_batch</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>TfmdDL.show_batch</code>(**`b`**=*`None`*, **`max_n`**=*`9`*, **`ctxs`**=*`None`*, **`show`**=*`True`*, **`unique`**=*`False`*, **\\*\\*`kwargs`**)\n", "\n", "Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a [`DataLoader`](/data.load.html#DataLoader))" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.show_batch)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"TfmdDL.to\" class=\"doc_header\"><code>TfmdDL.to</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>TfmdDL.to</code>(**`device`**)\n", "\n", "Put self and its transforms state on `device`" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdDL.to)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"DataLoaders.__getitem__\" class=\"doc_header\"><code>DataLoaders.__getitem__</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>DataLoaders.__getitem__</code>(**`i`**)\n", "\n", "Retrieve [`DataLoader`](/data.load.html#DataLoader) at `i` (`0` is training, `1` is validation)" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataLoaders.__getitem__)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"DataLoaders.train\" class=\"doc_header\"><code>DataLoaders.train</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataLoaders.train, name='DataLoaders.train')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"DataLoaders.valid\" class=\"doc_header\"><code>DataLoaders.valid</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataLoaders.valid, name='DataLoaders.valid')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"DataLoaders.train_ds\" class=\"doc_header\"><code>DataLoaders.train_ds</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataLoaders.train_ds, name='DataLoaders.train_ds')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"DataLoaders.valid_ds\" class=\"doc_header\"><code>DataLoaders.valid_ds</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(DataLoaders.valid_ds, name='DataLoaders.valid_ds')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"FilteredBase.dataloaders\" class=\"doc_header\"><code>FilteredBase.dataloaders</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>FilteredBase.dataloaders</code>(**`bs`**=*`64`*, **`shuffle_train`**=*`None`*, **`shuffle`**=*`True`*, **`val_shuffle`**=*`False`*, **`n`**=*`None`*, **`path`**=*`'.'`*, **`dl_type`**=*`None`*, **`dl_kwargs`**=*`None`*, **`device`**=*`None`*, **`drop_last`**=*`None`*, **`val_bs`**=*`None`*, **`num_workers`**=*`None`*, **`verbose`**=*`False`*, **`do_setup`**=*`True`*, **`pin_memory`**=*`False`*, **`timeout`**=*`0`*, **`batch_size`**=*`None`*, **`indexed`**=*`None`*, **`persistent_workers`**=*`False`*, **`wif`**=*`None`*, **`before_iter`**=*`None`*, **`after_item`**=*`None`*, **`before_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_iter`**=*`None`*, **`create_batches`**=*`None`*, **`create_item`**=*`None`*, **`create_batch`**=*`None`*, **`retain`**=*`None`*, **`get_idxs`**=*`None`*, **`sample`**=*`None`*, **`shuffle_fn`**=*`None`*, **`do_batch`**=*`None`*)\n", "\n" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(FilteredBase.dataloaders)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"TfmdLists.subset\" class=\"doc_header\"><code>TfmdLists.subset</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>TfmdLists.subset</code>(**`i`**)\n", "\n", "New [`TfmdLists`](/data.core.html#TfmdLists) with same tfms that only includes items in `i`th split" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdLists.subset)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"TfmdLists.infer_idx\" class=\"doc_header\"><code>TfmdLists.infer_idx</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>TfmdLists.infer_idx</code>(**`x`**)\n", "\n", "Finds the index where `self.tfms` can be applied to `x`, depending on the type of `x`" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdLists.infer_idx)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"TfmdLists.infer\" class=\"doc_header\"><code>TfmdLists.infer</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>TfmdLists.infer</code>(**`x`**)\n", "\n", "Apply `self.tfms` to `x` starting at the right tfm depending on the type of `x`" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TfmdLists.infer)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"Datasets.dataloaders\" class=\"doc_header\"><code>Datasets.dataloaders</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>Datasets.dataloaders</code>(**`bs`**=*`64`*, **`shuffle_train`**=*`None`*, **`shuffle`**=*`True`*, **`val_shuffle`**=*`False`*, **`n`**=*`None`*, **`path`**=*`'.'`*, **`dl_type`**=*`None`*, **`dl_kwargs`**=*`None`*, **`device`**=*`None`*, **`drop_last`**=*`None`*, **`val_bs`**=*`None`*, **`num_workers`**=*`None`*, **`verbose`**=*`False`*, **`do_setup`**=*`True`*, **`pin_memory`**=*`False`*, **`timeout`**=*`0`*, **`batch_size`**=*`None`*, **`indexed`**=*`None`*, **`persistent_workers`**=*`False`*, **`wif`**=*`None`*, **`before_iter`**=*`None`*, **`after_item`**=*`None`*, **`before_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_iter`**=*`None`*, **`create_batches`**=*`None`*, **`create_item`**=*`None`*, **`create_batch`**=*`None`*, **`retain`**=*`None`*, **`get_idxs`**=*`None`*, **`sample`**=*`None`*, **`shuffle_fn`**=*`None`*, **`do_batch`**=*`None`*)\n", "\n", "Get a [`DataLoaders`](/data.core.html#DataLoaders)" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Datasets.dataloaders)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"Datasets.decode\" class=\"doc_header\"><code>Datasets.decode</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>Datasets.decode</code>(**`o`**, **`full`**=*`True`*)\n", "\n", "Compose `decode` of all `tuple_tfms` then all `tfms` on `i`" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Datasets.decode)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"Datasets.show\" class=\"doc_header\"><code>Datasets.show</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>Datasets.show</code>(**`o`**, **`ctx`**=*`None`*, **\\*\\*`kwargs`**)\n", "\n", "Show item `o` in `ctx`" ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Datasets.show)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h4 id=\"Datasets.new_empty\" class=\"doc_header\"><code>Datasets.new_empty</code><a class=\"source_link\" style=\"float:right\">[source]</a></h4>\n", "\n", "> <code>Datasets.new_empty</code>()\n", "\n", "Create a new empty version of the `self`, keeping only the transforms" ], "text/plain": [