{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp data.load" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.torch_basics import *\n", "from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind\n", "_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs = 4\n", "letters = list(string.ascii_lowercase)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DataLoader helpers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai includes a replacement for PyTorch's *DataLoader* which is largely API-compatible, and adds a lot of useful functionality and flexibility. Before we look at the class, there are a couple of helpers we'll need to define." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _wif(worker_id):\n", " # Worker-init hook run inside each DataLoader worker process: pin native threads\n", " # to 1, record this worker's count and offset on the wrapped d object, seed RNGs\n", " # from the worker seed, then invoke the dataset's own wif() callback.\n", " set_num_threads(1)\n", " info = get_worker_info()\n", " ds = info.dataset.d\n", " ds.num_workers,ds.offs = info.num_workers,info.id\n", " set_seed(info.seed)\n", " ds.wif()\n", "\n", "# _FakeLoader impersonates a torch DataLoader for the _loaders iterator classes\n", "# imported above: it exposes the attributes those internals read (collate_fn,\n", "# _index_sampler, dataset_kind, ...) with neutral/no-op values, while __iter__\n", "# delegates batch creation back to the wrapped object d.\n", "class _FakeLoader:\n", " # identity function, used as the do-nothing collate_fn below\n", " def _fn_noops(self, x=None, *args, **kwargs): return x\n", " \n", " _IterableDataset_len_called,_auto_collation,collate_fn,drop_last = None,False,_fn_noops,False\n", " _index_sampler,generator,prefetch_factor = Inf.count,None,2\n", " dataset_kind = _dataset_kind = _DatasetKind.Iterable\n", " \n", " def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers):\n", " self.dataset,self.default,self.worker_init_fn = self,d,_wif\n", " store_attr('d,pin_memory,num_workers,timeout,persistent_workers')\n", "\n", " def __iter__(self): return iter(self.d.create_batches(self.d.sample()))\n", "\n", " @property\n", " def multiprocessing_context(self): return (None,multiprocessing)[self.num_workers>0]\n", "\n", " @contextmanager\n", " def no_multiproc(self):\n", " # Context manager: temporarily force single-process loading (num_workers=0),\n", " # restoring the previous worker count on exit.\n", " old_num_workers = self.num_workers\n", " try:\n", " self.num_workers = 0\n", " yield self.d\n", " finally: self.num_workers = old_num_workers\n", "\n", "# leaf types handed straight to torch's default_collate/default_convert\n", "_collate_types = (ndarray, Tensor, typing.Mapping, str)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def fa_collate(t):\n", " \"A replacement for PyTorch `default_collate` which maintains types and handles `Sequence`s\"\n", " b = t[0]\n", " return (default_collate(t) if isinstance(b, _collate_types)\n", " else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)\n", " else default_collate(t))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#e.g. 
x is int, y is tuple\n", "t = [(1,(2,3)),(1,(2,3))]\n", "test_eq(fa_collate(t), default_collate(t))\n", "test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])\n", "\n", "# nested sequences are collated recursively, preserving the tuple structure\n", "t = [(1,(2,(3,4))),(1,(2,(3,4)))]\n", "test_eq(fa_collate(t), default_collate(t))\n", "test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])\n", "test_eq(L(fa_collate(t)[1]).map(type), [Tensor,tuple])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def fa_convert(t):\n", " \"A replacement for PyTorch `default_convert` which maintains types and handles `Sequence`s\"\n", " # leaf _collate_types go straight to default_convert; Sequences recurse,\n", " # rebuilding the same concrete sequence type\n", " return (default_convert(t) if isinstance(t, _collate_types)\n", " else type(t)([fa_convert(s) for s in t]) if isinstance(t, Sequence)\n", " else default_convert(t))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t0 = array([1,2])\n", "t = [t0,(t0,t0)]\n", "\n", "test_eq(fa_convert(t), default_convert(t))\n", "test_eq(L(fa_convert(t)).map(type), [Tensor,tuple])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class SkipItemException(Exception):\n", " \"Raised to notify `DataLoader` to skip an item\"\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
class
SkipItemException
[source]SkipItemException
() :: `Exception`\n",
"\n",
"Raised to notify [`DataLoader`](/data.load.html#DataLoader) to skip an item"
],
"text/plain": [
"