#export
from fastai.torch_basics import *
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

# Small fixtures for the example/test cells of this notebook: `letters` serves as a toy
# dataset of the 26 lowercase ASCII letters; `bs` is presumably a default batch size for
# those cells -- TODO confirm it is still referenced.
bs = 4
letters = list(string.ascii_lowercase)

## DataLoader helpers

fastai includes a replacement for PyTorch's *DataLoader* which is largely API-compatible, and adds a lot of useful functionality and flexibility. Before we look at the class, there are a couple of helpers we'll need to define.

#export
def _wif(worker_id):
    "Worker init function used by `_FakeLoader`: configure the wrapped loader inside a worker process"
    # `worker_id` is unused; everything needed is read from the worker info object instead
    set_num_threads(1)
    worker = get_worker_info()
    # `worker.dataset` is the `_FakeLoader`; its `d` attribute is the loader it wraps
    loader = worker.dataset.d
    n_workers, offset = worker.num_workers, worker.id
    loader.num_workers = n_workers
    loader.offs = offset
    set_seed(worker.seed)
    # Finally run the wrapped loader's own `wif` hook
    loader.wif()

class _FakeLoader:
    "Stand-in loader object that yields the batches produced by the wrapped loader `d`"
    def _fn_noops(self, x=None, *args, **kwargs):
        "Return `x` untouched, ignoring every other argument"
        return x

    # Class-level settings -- presumably the attributes PyTorch's loader iterators
    # look up on the loader object they are given (TODO confirm against PyTorch)
    _IterableDataset_len_called = None
    _auto_collation = False
    collate_fn = _fn_noops
    drop_last = False
    _index_sampler = Inf.count
    generator = None
    prefetch_factor = 2
    _dataset_kind = _DatasetKind.Iterable
    dataset_kind = _dataset_kind

    def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers):
        # This object acts as its own dataset; `_wif` later reaches `d` through `dataset.d`
        self.dataset = self
        self.default = d
        self.worker_init_fn = _wif
        store_attr('d,pin_memory,num_workers,timeout,persistent_workers')

    def __iter__(self):
        "Iterate over the batches `d` creates from its own samples"
        samples = self.d.sample()
        return iter(self.d.create_batches(samples))

    @property
    def multiprocessing_context(self):
        "The `multiprocessing` module when workers are in use, otherwise `None`"
        return multiprocessing if self.num_workers>0 else None

    @contextmanager
    def no_multiproc(self):
        "Temporarily set `num_workers` to 0, yielding `d`; the previous value is always restored"
        saved = self.num_workers
        try:
            self.num_workers = 0
            yield self.d
        finally:
            self.num_workers = saved

# Types that `fa_collate`/`fa_convert` hand straight to PyTorch's default implementation
# rather than recursing into them (note `str` is listed here even though it is a sequence).
_collate_types = (ndarray, Tensor, typing.Mapping, str)

#export
def fa_collate(t):
    "A replacement for PyTorch `default_collate` which maintains types and handles `Sequence`s"
    sample = t[0]
    # Directly collatable types, and anything that is not a sequence, go to PyTorch as-is
    if isinstance(sample, _collate_types) or not isinstance(sample, Sequence):
        return default_collate(t)
    # Otherwise collate position by position and rebuild the first sample's container type
    return type(sample)([fa_collate(column) for column in zip(*t)])

# e.g. each sample is (x, y) where x is an int and y is a tuple:
# the values must match `default_collate`, and the collated y must still be a tuple
t = [(1,(2,3)),(1,(2,3))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])

# Same check one level deeper: the tuple nested inside y must also stay a tuple
t = [(1,(2,(3,4))),(1,(2,(3,4)))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])
test_eq(L(fa_collate(t)[1]).map(type), [Tensor,tuple])

#export
def fa_convert(t):
    "A replacement for PyTorch `default_convert` which maintains types and handles `Sequence`s"
    # Directly convertible types, and anything that is not a sequence, go to PyTorch as-is
    if isinstance(t, _collate_types) or not isinstance(t, Sequence):
        return default_convert(t)
    # Otherwise convert element by element and rebuild the original container type
    return type(t)([fa_convert(item) for item in t])

# A list holding an array and a tuple of arrays
t0 = array([1,2])
t = [t0,(t0,t0)]

# Values must match `default_convert`, and the inner tuple must still be a tuple
test_eq(fa_convert(t), default_convert(t))
test_eq(L(fa_convert(t)).map(type), [Tensor,tuple])

#export
class SkipItemException(Exception):
    "Raised to notify `DataLoader` to skip an item"

class SkipItemException[source]

\n", "\n", "> SkipItemException() :: `Exception`\n", "\n", "Raised to notify [`DataLoader`](/data.load.html#DataLoader) to skip an item" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(SkipItemException, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DataLoader -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@funcs_kwargs\n", "class DataLoader(GetAttr):\n", " _noop_methods = 'wif before_iter after_item before_batch after_batch after_iter'.split()\n", " for o in _noop_methods: exec(f\"def {o}(self, x=None, *args, **kwargs): return x\")\n", " _methods = _noop_methods + 'create_batches create_item create_batch retain \\\n", " get_idxs sample shuffle_fn do_batch create_batch'.split()\n", " _default = 'dataset'\n", " def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,\n", " shuffle=False, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, **kwargs):\n", " if batch_size is not None: bs = batch_size # PyTorch compatibility\n", " assert not (bs is None and drop_last)\n", " if indexed is None: indexed = (hasattr(dataset,'__getitem__')\n", " and not isinstance(dataset, IterableDataset))\n", " if not indexed and shuffle: raise ValueError(\"Can only shuffle an indexed dataset (not an iterable one).\")\n", " if n is None:\n", " try: n = len(dataset)\n", " except TypeError: pass\n", " store_attr('dataset,bs,shuffle,drop_last,indexed,n,pin_memory,timeout,device')\n", " self.rng,self.num_workers,self.offs = random.Random(random.randint(0,2**32-1)),1,0\n", " if sys.platform == \"win32\" and IN_NOTEBOOK and num_workers > 0:\n", " print(\"Due to IPython and Windows limitation, python multiprocessing isn't available now.\")\n", " print(\"So `number_workers` is changed to 0 to avoid getting stuck\")\n", " num_workers = 0 \n", " self.fake_l = _FakeLoader(self, 
pin_memory, num_workers, timeout, persistent_workers=persistent_workers)\n", "\n", " def __len__(self):\n", " if self.n is None: raise TypeError\n", " if self.bs is None: return self.n\n", " return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)\n", "\n", " def get_idxs(self):\n", " idxs = Inf.count if self.indexed else Inf.nones\n", " if self.n is not None: idxs = list(itertools.islice(idxs, self.n))\n", " if self.shuffle: idxs = self.shuffle_fn(idxs)\n", " return idxs\n", " \n", " def sample(self): \n", " return (b for i,b in enumerate(self.__idxs) if i//(self.bs or 1)%self.num_workers==self.offs)\n", "\n", " def __iter__(self):\n", " self.randomize()\n", " self.before_iter()\n", " self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)\n", " for b in _loaders[self.fake_l.num_workers==0](self.fake_l):\n", " if self.device is not None: b = to_device(b, self.device)\n", " yield self.after_batch(b)\n", " self.after_iter()\n", " if hasattr(self, 'it'): del(self.it)\n", "\n", " def create_batches(self, samps):\n", " if self.dataset is not None: self.it = iter(self.dataset)\n", " res = filter(lambda o:o is not None, map(self.do_item, samps))\n", " yield from map(self.do_batch, self.chunkify(res))\n", "\n", " def new(self, dataset=None, cls=None, **kwargs):\n", " if dataset is None: dataset = self.dataset\n", " if cls is None: cls = type(self)\n", " cur_kwargs = dict(dataset=dataset, num_workers=self.fake_l.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,\n", " bs=self.bs, shuffle=self.shuffle, drop_last=self.drop_last, indexed=self.indexed, device=self.device)\n", " for n in self._methods:\n", " o = getattr(self, n)\n", " if not isinstance(o, MethodType): cur_kwargs[n] = o\n", " return cls(**merge(cur_kwargs, kwargs))\n", "\n", " @property\n", " def prebatched(self): return self.bs is None\n", " def do_item(self, s):\n", " try: return self.after_item(self.create_item(s))\n", " except 
SkipItemException: return None\n", " def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)\n", " def shuffle_fn(self, idxs): return self.rng.sample(idxs, len(idxs))\n", " def randomize(self): self.rng = random.Random(self.rng.randint(0,2**32-1))\n", " def retain(self, res, b): return retain_types(res, b[0] if is_listy(b) else b)\n", " def create_item(self, s):\n", " if self.indexed: return self.dataset[s or 0]\n", " elif s is None: return next(self.it)\n", " else: raise IndexError(\"Cannot index an iterable dataset numerically - must use `None`.\")\n", " def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)\n", " def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)\n", " def to(self, device): self.device = device\n", " def one_batch(self):\n", " if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')\n", " with self.fake_l.no_multiproc(): res = first(self)\n", " if hasattr(self, 'it'): delattr(self, 'it')\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "add_docs(DataLoader, \"API compatible with PyTorch DataLoader, with a lot more callbacks and flexibility\",\n", " get_idxs = \"Return a list of indices to reference the dataset. Calls `shuffle_fn` internally if `shuffle=True`.\",\n", " sample = \"Same as `get_idxs` but returns a generator of indices to reference the dataset.\",\n", " create_batches = \"Takes output of `sample` as input, and returns batches of data. 
Does not apply `after_batch`.\",\n", " new = \"Create a new `DataLoader` with given arguments keeping remaining arguments same as original `DataLoader`.\",\n", " prebatched = \"Check if `bs` is None.\",\n", " do_item = \"Combines `after_item` and `create_item` to get an item from dataset by providing index as input.\",\n", " chunkify = \"Used by `create_batches` to turn generator of items (`b`) into batches.\",\n", " shuffle_fn = \"Returns a random permutation of `idxs`.\",\n", " randomize = \"Set's `DataLoader` random number generator state.\",\n", " retain = \"Cast each item of `res` to type of matching item in `b` if its a superclass.\",\n", " create_item = \"Subset of the dataset containing the index values of sample if exists, else next iterator.\",\n", " create_batch = \"Collate a list of items into a batch.\",\n", " do_batch = \"Combines `create_batch` and `before_batch` to get a batch of items. Input is a list of items to collate.\",\n", " to = \"Sets `self.device=device`.\",\n", " one_batch = \"Return one batch from `DataLoader`.\",\n", " wif = \"See pytorch `worker_init_fn` for details.\", \n", " before_iter = \"Called before `DataLoader` starts to read/iterate over the dataset.\",\n", " after_item = \"Takes output of `create_item` as input and applies this function on it.\",\n", " before_batch = \"It is called before collating a list of items into a batch. Input is a list of items.\",\n", " after_batch = \"After collating mini-batch of items, the mini-batch is passed through this function.\",\n", " after_iter = \"Called after `DataLoader` has fully read/iterated over the dataset.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Arguments to `DataLoader`:\n", "* `dataset`: dataset from which to load the data. Can be either map-style or iterable-style dataset.\n", "* `bs` (int): how many samples per batch to load (if `batch_size` is provided then `batch_size` will override `bs`). 
If `bs=None`, then it is assumed that `dataset.__getitem__` returns a batch.\n", "* `num_workers` (int): how many subprocesses to use for data loading. `0` means that the data will be loaded in the main process.\n", "* `pin_memory` (bool): If `True`, the data loader will copy Tensors into CUDA pinned memory before returning them.\n", "* `timeout` (float>0): the timeout value in seconds for collecting a batch from workers.\n", "* `batch_size` (int): It is only provided for PyTorch compatibility. Use `bs`.\n", "* `shuffle` (bool): If `True`, then data is shuffled every time dataloader is fully read/iterated.\n", "* `drop_last` (bool): If `True`, then the last incomplete batch is dropped.\n", "* `indexed` (bool): The `DataLoader` will make a guess as to whether the dataset can be indexed (or is iterable), but you can override it with this parameter. `True` by default.\n", "* `n` (int): Defaults to `len(dataset)`. If you are using iterable-style dataset, you can specify the size with `n`.\n", "* `device` (torch.device): Defaults to `default_device()` which is CUDA by default. You can specify device as `torch.device('cpu')`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Override `item` and use the default infinite sampler to get a stream of unknown length (`stop()` when you want to stop the stream)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#1) [0.8184584259771384]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class RandDL(DataLoader):\n", " def create_item(self, s):\n", " r = random.random()\n", " return r if r<0.95 else stop()\n", "\n", "L(RandDL())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#9) [4,4,4,4,4,4,4,4,4]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "L(RandDL(bs=4, drop_last=True)).map(len)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#16) [4,4,4,4,4,4,4,4,4,4...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dl = RandDL(bs=4, num_workers=4, drop_last=True)\n", "L(dl).map(len)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_num_workers = 0 if sys.platform == \"win32\" else 4\n", "test_eq(dl.fake_l.num_workers, test_num_workers)\n", "with dl.fake_l.no_multiproc(): \n", " test_eq(dl.fake_l.num_workers, 0)\n", " L(dl).map(len)\n", "test_eq(dl.fake_l.num_workers, test_num_workers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#15) [0.09997947451517075,0.9406352068577291,0.7365644613807534,0.03467093036515401,0.33835398224425894,0.8449277297768522,0.38410576483536174,0.6763438764319878,0.748649743338762,0.27015475248670817...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def _rand_item(s):\n", " r = random.random()\n", " return r if r<0.95 else stop()\n", "\n", "L(DataLoader(create_item=_rand_item))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you don't set `bs`, then `dataset` is assumed to provide an 
iterator or a `__getitem__` that returns a batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds1 = DataLoader(letters)\n", "test_eq(L(ds1), letters)\n", "test_eq(len(ds1), 26)\n", "\n", "test_shuffled(L(DataLoader(letters, shuffle=True)), letters)\n", "\n", "ds1 = DataLoader(letters, indexed=False)\n", "test_eq(L(ds1), letters)\n", "test_eq(len(ds1), 26)\n", "\n", "t2 = L(tensor([0,1,2]),tensor([3,4,5]))\n", "ds2 = DataLoader(t2)\n", "test_eq_type(L(ds2), t2)\n", "\n", "t3 = L(array([0,1,2], dtype=np.int64),array([3,4,5], dtype=np.int64))\n", "ds3 = DataLoader(t3)\n", "test_eq_type(L(ds3), t3.map(tensor))\n", "\n", "ds4 = DataLoader(t3, create_batch=noop, after_iter=lambda: setattr(t3, 'f', 1))\n", "test_eq_type(L(ds4), t3)\n", "test_eq(t3.f, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you do set `bs`, then `dataset` is assumed to provide an iterator or a `__getitem__` that returns a single item of a batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def twoepochs(d): return ' '.join(''.join(list(o)) for _ in range(2) for o in d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds1 = DataLoader(letters, bs=4, drop_last=True, num_workers=0)\n", "test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')\n", "\n", "ds1 = DataLoader(letters,4,num_workers=2)\n", "test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')\n", "\n", "ds1 = DataLoader(range(12), bs=4, num_workers=3)\n", "test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))\n", "\n", "ds1 = DataLoader([str(i) for i in range(11)], bs=4, after_iter=lambda: setattr(t3, 'f', 2))\n", "test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))\n", "test_eq(t3.f, 2)\n", "\n", "it = iter(DataLoader(map(noop,range(20)), bs=4, 
num_workers=1))\n", "test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Iterable dataloaders require specific tests." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class DummyIterableDataset(IterableDataset):\n", " def __iter__(self):\n", " yield from range(11)\n", "\n", "ds1 = DataLoader(DummyIterableDataset(), bs=4)\n", "# Check it yields fine, and check we can do multiple passes\n", "for i in range(3):\n", " test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10])))\n", "\n", "# Check `drop_last` works fine (with multiple passes, since this will prematurely terminate the iterator)\n", "ds1 = DataLoader(DummyIterableDataset(), bs=4, drop_last=True)\n", "for i in range(3):\n", " test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7])))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.27 ms, sys: 1.05 ms, total: 5.32 ms\n", "Wall time: 316 ms\n", "CPU times: user 12.7 ms, sys: 11.9 ms, total: 24.5 ms\n", "Wall time: 197 ms\n", "CPU times: user 14.5 ms, sys: 16.2 ms, total: 30.7 ms\n", "Wall time: 127 ms\n" ] }, { "data": { "text/plain": [ "(#26) ['r','c','q','n','j','s','l','p','b','y'...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class SleepyDL(list):\n", " def __getitem__(self,i):\n", " time.sleep(random.random()/50)\n", " return super().__getitem__(i)\n", "\n", "t = SleepyDL(letters)\n", "\n", "%time test_eq(DataLoader(t, num_workers=0), letters)\n", "%time test_eq(DataLoader(t, num_workers=2), letters)\n", "%time test_eq(DataLoader(t, num_workers=4), letters)\n", "\n", "dl = DataLoader(t, shuffle=True, num_workers=1)\n", "test_shuffled(L(dl), letters)\n", "test_shuffled(L(dl), L(dl))\n", "L(dl)" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 12 ms, sys: 22.3 ms, total: 34.3 ms\n", "Wall time: 130 ms\n" ] } ], "source": [ "class SleepyQueue():\n", " \"Simulate a queue with varying latency\"\n", " def __init__(self, q): self.q=q\n", " def __iter__(self):\n", " while True:\n", " time.sleep(random.random()/100)\n", " try: yield self.q.get_nowait()\n", " except queues.Empty: return\n", "\n", "q = Queue()\n", "for o in range(30): q.put(o)\n", "it = SleepyQueue(q)\n", "\n", "if not (sys.platform == \"win32\" and IN_NOTEBOOK):\n", " %time test_shuffled(L(DataLoader(it, num_workers=4)), L(range(30)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(TensorBase): pass\n", "\n", "for nw in (0,2):\n", " t = A(tensor([1,2]))\n", " dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)\n", " b = first(dl)\n", " test_eq(type(b), A)\n", "\n", " t = (A(tensor([1,2])),)\n", " dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)\n", " b = first(dl)\n", " test_eq(type(b[0]), A)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([16, 14, 5, 1, 39, 49, 10, 40, 7, 36, 28, 42, 32, 24, 43, 46, 4, 3,\n", " 11, 48, 26, 35, 15, 25, 23, 8, 44, 47, 0, 34, 21, 17]),\n", " tensor([45, 41, 6, 20, 38, 19, 29, 37, 13, 18, 2, 27, 30, 12, 33, 22, 9, 31])]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(DataLoader(list(range(50)),bs=32,shuffle=True,num_workers=3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class A(TensorBase): pass\n", "t = A(tensor(1,2))\n", "\n", "tdl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=2, after_batch=to_device)\n", "b = first(tdl)\n", "test_eq(type(b), A)\n", "\n", "# Unknown attributes are delegated to `dataset`\n", "test_eq(tdl.pop(), 
tensor(1,2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Override `get_idxs` to return the same index until consumption of the DL. 