# %%
# default_exp distributed

# %%
#export
from local.basics import *
from local.callback.progress import ProgressCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler

# %%
from local.test import *

# %% [markdown]
# # Distributed and parallel training
#
# > Callbacks and helper functions to train in parallel or use distributed training

# %% [markdown]
# ## Parallel

# %% [markdown]
# Patch the parallel models so they work with RNNs

# %%
#export
@patch
def reset(self: DataParallel):
    "Forward `reset` to the wrapped `module` when it defines one."
    if hasattr(self.module, 'reset'): self.module.reset()

# %%
#export
class ParallelTrainer(Callback):
    "Wrap the model in `DataParallel` over `device_ids` in `begin_fit` and unwrap it in `after_fit`."
    run_after,run_before = TrainEvalCallback,Recorder
    def __init__(self, device_ids): self.device_ids = device_ids

    def begin_fit(self):
        "Replace the learner's model with a `DataParallel` wrapper around it."
        self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)

    def after_fit(self):
        "Put the wrapped module back as the learner's model."
        self.learn.model = self.learn.model.module

# %%
#export
@patch
def to_parallel(self: Learner, device_ids=None):
    "Add a `ParallelTrainer` using `device_ids` and return the learner."
    self.add_cb(ParallelTrainer(device_ids))
    return self

# %% [markdown]
# ## Distributed

# %% [markdown]
# Patch the parallel models so they work with RNNs

# %%
#export
@patch
def reset(self: DistributedDataParallel):
    "Forward `reset` to the wrapped `module` when it defines one."
    if hasattr(self.module, 'reset'): self.module.reset()

# %%
#export
def setup_distrib(gpu=None):
    "Make `gpu` the current CUDA device and, if `num_distrib() > 1`, init the NCCL process group from env vars."
    if gpu is None: return gpu
    gpu = int(gpu)
    torch.cuda.set_device(gpu)
    if num_distrib() > 1:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return gpu

# %% [markdown]
# ### DataLoader

# %% [markdown]
# We need to change the dataloaders so that they only get one part of the batch each (otherwise there is no point in using distributed training).

# %%
#export
@delegates()
class DistributedDL(TfmdDL):
    "A `TfmdDL` that keeps every `world_size`-th index starting at `rank`, so each process reads its own share."

    def __init__(self, dataset, rank, world_size, **kwargs):
        super().__init__(dataset, **kwargs)
        # Round the length up to a multiple of `world_size` so every rank draws the same number of items.
        if self.n%world_size != 0: self.n += world_size-self.n%world_size
        # `total_n` is the padded global length, `n` the per-rank length.
        self.total_n,self.n = self.n,self.n//world_size
        store_attr(self, 'rank,world_size')
        # `shuffle_fn` seeds its generator from `self.epoch`, which was previously only assigned by
        # `set_epoch`; default it so a shuffled loader can be iterated before `set_epoch` is called.
        if not hasattr(self, 'epoch'): self.epoch = 0

    def get_idxs(self):
        "Return the first `total_n` indices (or the endless index source when `n` is None)."
        idxs = Inf.count if self.indexed else Inf.nones
        return idxs if self.n is None else list(itertools.islice(idxs, self.total_n))

    def shuffle_fn(self, idxs):
        "Deterministically shuffle on each training process based on epoch."
        g = torch.Generator()
        # Same epoch => same seed => same permutation, whichever process computes it.
        g.manual_seed(self.epoch)
        return L(idxs)[torch.randperm(self.total_n, generator=g)]

    def sample(self):
        "Yield this rank's indices, restricted to the batches that match `self.offs`."
        idxs = self.get_idxs()
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        # add extra samples to make it evenly divisible
        idxs += idxs[:(self.total_n - len(idxs))]
        # subsample: every `world_size`-th index starting at `rank`
        idxs = idxs[self.rank:self.total_n:self.world_size]
        # Keep position `i` when its batch number (`i // bs`) modulo `self.nw` equals `self.offs`.
        # NOTE(review): presumably `nw`/`offs` are the worker count and this worker's offset set by the
        # base loader -- confirm against `TfmdDL`.
        return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)

    def create_item(self, s):
        "Create item `s`, wrapping indices past the end of the dataset back to the start."
        if s is not None and s >= len(self.dataset): s = s%len(self.dataset)
        return super().create_item(s)

    def set_epoch(self, epoch):
        "Store `epoch`, which `shuffle_fn` uses as its seed."
        self.epoch = epoch

    @classmethod
    def from_dl(cls, dl, rank, world_size, **kwargs):
        "Build a `DistributedDL` over `dl.dataset` from the settings of `dl`, combined with `kwargs` through `merge`."
        cur_kwargs = dict(num_workers=dl.fake_l.num_workers, pin_memory=dl.pin_memory, timeout=dl.timeout,
                          bs=dl.bs, shuffle=dl.shuffle, drop_last=dl.drop_last, indexed=dl.indexed)
        # Carry over `dl`'s entries from `cls._methods`, except the three this class overrides itself.
        cur_kwargs.update({n: getattr(dl, n) for n in cls._methods if n not in "sample shuffle_fn create_item".split()})
        return cls(dl.dataset, rank, world_size, **merge(cur_kwargs, kwargs))

# %%
dl = TfmdDL(list(range(50)), bs=16, num_workers=2)
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    test_eq(list(dl1)[0], torch.arange(i, 52, 4)%50)

# %%
dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)
res = []
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    dl1.set_epoch(0)
    res += list(dl1)[0].tolist()
#All items should only be accessed once (except 0 and 1 for final cycle) with seeded shuffle
test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))

# %%
#export
class DistributedTrainer(Callback):
    "Wrap the model in `DistributedDataParallel` and swap in `DistributedDL`s in `begin_fit`; undo both in `after_fit`."
    run_after,run_before = TrainEvalCallback,Recorder
    def __init__(self, cuda_id=0): self.cuda_id = cuda_id

    def begin_fit(self):
        "Wrap the model, replace each loader with its `DistributedDL` version and silence logging on ranks > 0."
        self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)
        # Keep the original loaders so `after_fit` can put them back.
        self.old_dls = list(self.dbunch.dls)
        self.learn.dbunch.dls = [DistributedDL.from_dl(dl, rank_distrib(), num_distrib()) for dl in self.dbunch.dls]
        # NOTE(review): the logger is replaced here but not restored in `after_fit` -- confirm this is intended.
        if rank_distrib() > 0: self.learn.logger=noop

    def begin_epoch(self):
        "Pass the current epoch to every loader so `shuffle_fn` seeds from it."
        for dl in self.dbunch.dls: dl.set_epoch(self.epoch)

    def after_fit(self):
        "Restore the unwrapped model and the original loaders."
        self.learn.model = self.learn.model.module
        self.learn.dbunch.dls = self.old_dls

# %%
#export
@patch
def to_distributed(self: Learner, cuda_id):
    "Add a `DistributedTrainer` on `cuda_id`; on ranks > 0 also remove the `progress` callback if present. Returns the learner."
    self.add_cb(DistributedTrainer(cuda_id))
    # Guard the lookup: a learner without a `progress` attribute used to raise AttributeError here.
    # NOTE(review): presumably `progress` is set when a `ProgressCallback` is attached -- confirm in `Learner.add_cb`.
    if rank_distrib() > 0 and hasattr(self, 'progress'): self.remove_cb(self.progress)
    return self

# %% [markdown]
# ## Export -

# %%
#hide
from local.notebook.export import notebook2script
notebook2script(all_fs=True)