{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp distributed" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.basics import *\n", "from fastai2.callback.progress import ProgressCallback\n", "from torch.nn.parallel import DistributedDataParallel, DataParallel\n", "from torch.utils.data.distributed import DistributedSampler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Distributed and parallel training\n", "\n", "> Callbacks and helper functions to train in parallel or use distributed training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parallel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Patch the parallel models so they work with RNNs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def reset(self: DataParallel):\n", " if hasattr(self.module, 'reset'): self.module.reset()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@log_args\n", "class ParallelTrainer(Callback):\n", " run_after,run_before = TrainEvalCallback,Recorder\n", " def __init__(self, device_ids): self.device_ids = device_ids\n", " def before_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)\n", " def after_fit(self): self.learn.model = self.learn.model.module" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def to_parallel(self: Learner, device_ids=None):\n", " self.add_cb(ParallelTrainer(device_ids))\n", " return self" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def detach_parallel(self: Learner):\n", " \"Remove ParallelTrainer callback from Learner.\"\n", " self.remove_cb(ParallelTrainer)\n", " return self" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "@contextmanager\n", "def parallel_ctx(self: Learner, device_ids=None):\n", " \"A context manager to adapt a learner to train in data parallel mode.\"\n", " try:\n", " self.to_parallel(device_ids)\n", " yield self\n", " finally:\n", " self.detach_parallel()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Distributed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Patch the parallel models so they work with RNNs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def reset(self: DistributedDataParallel):\n", " if hasattr(self.module, 'reset'): self.module.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convenience functions to set up/tear down torch distributed data parallel mode." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def setup_distrib(gpu=None):\n", " if gpu is None: return gpu\n", " gpu = int(gpu)\n", " torch.cuda.set_device(int(gpu))\n", " if num_distrib() > 1:\n", " torch.distributed.init_process_group(backend='nccl', init_method='env://')\n", " return gpu" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def teardown_distrib():\n", " if torch.distributed.is_initialized(): torch.distributed.destroy_process_group()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DataLoader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to change the dataloaders so that they only get one part of the batch each (otherwise there is no point in using distributed training)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@log_args(but_as=TfmdDL.__init__)\n", "@delegates()\n", "class DistributedDL(TfmdDL):\n", "\n", " def __init__(self, dataset, rank, world_size, **kwargs):\n", " super().__init__(dataset, **kwargs)\n", " if self.n%world_size != 0: self.n += world_size-self.n%world_size\n", " self.total_n,self.n = self.n,self.n//world_size\n", " store_attr(self, 'rank,world_size')\n", "\n", " def get_idxs(self):\n", " idxs = Inf.count if self.indexed else Inf.nones\n", " return idxs if self.n is None else list(itertools.islice(idxs, self.total_n))\n", "\n", " def shuffle_fn(self, idxs):\n", " \"Deterministically shuffle on each training process based on epoch.\"\n", " g = torch.Generator()\n", " g.manual_seed(self.epoch)\n", " return L(idxs)[torch.randperm(self.total_n, generator=g)]\n", "\n", " def sample(self):\n", " idxs = self.get_idxs()\n", " if self.shuffle: idxs = self.shuffle_fn(idxs)\n", " # add extra samples to make it evenly divisible\n", " idxs += idxs[:(self.total_n - len(idxs))]\n", " # subsample\n", " idxs = idxs[self.rank:self.total_n:self.world_size]\n", " return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)\n", "\n", " def create_item(self, s):\n", " if s is not None and s >= len(self.dataset): s = s%len(self.dataset)\n", " return s if hasattr(self.dataset, 'iloc') else super().create_item(s)\n", "\n", " def set_epoch(self, epoch): self.epoch = epoch\n", "\n", " @classmethod\n", " def from_dl(cls, dl, rank, world_size, **kwargs):\n", " cur_kwargs = dict(num_workers=dl.fake_l.num_workers, pin_memory=dl.pin_memory, timeout=dl.timeout,\n", " bs=dl.bs, shuffle=dl.shuffle, drop_last=dl.drop_last, indexed=dl.indexed, device=dl.device)\n", " cur_kwargs.update({n: getattr(dl, n) for n in cls._methods if n not in \"get_idxs sample shuffle_fn create_item\".split()})\n", " return cls(dl.dataset, rank, world_size, **merge(cur_kwargs, kwargs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dl = TfmdDL(list(range(50)), bs=16, num_workers=2)\n", "for i in range(4):\n", " dl1 = DistributedDL.from_dl(dl, i, 4)\n", " test_eq(list(dl1)[0], torch.arange(i, 52, 4)%50)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)\n", "res = []\n", "for i in range(4):\n", " dl1 = DistributedDL.from_dl(dl, i, 4)\n", " dl1.set_epoch(0)\n", " res += list(dl1)[0].tolist()\n", "#All items should only be accessed once (except 0 and 1 for final cycle) with seeded shuffle\n", 
"test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@log_args\n", "class DistributedTrainer(Callback):\n", " run_after,run_before = TrainEvalCallback,Recorder\n", " fup = None # for `find_unused_parameters` in DistributedDataParallel()\n", " def __init__(self, cuda_id=0): self.cuda_id = cuda_id\n", "\n", " def before_fit(self):\n", " opt_kwargs = { 'find_unused_parameters' : DistributedTrainer.fup } if DistributedTrainer.fup is not None else {}\n", " self.learn.model = DistributedDataParallel(self.model, device_ids=[self.cuda_id], output_device=self.cuda_id, **opt_kwargs)\n", " self.old_dls = list(self.dls)\n", " self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls]\n", " if rank_distrib() > 0: self.learn.logger=noop\n", "\n", " def _wrap_dl(self, dl):\n", " return dl if isinstance(dl, DistributedDL) else DistributedDL.from_dl(dl, rank_distrib(), num_distrib())\n", "\n", " def before_epoch(self):\n", " for dl in self.dls: dl.set_epoch(self.epoch)\n", "\n", " def before_train(self): self.learn.dl = self._wrap_dl(self.learn.dl)\n", " def before_validate(self): self.learn.dl = self._wrap_dl(self.learn.dl)\n", "\n", " def after_fit(self):\n", " self.learn.model = self.learn.model.module\n", " self.learn.dls.loaders = self.old_dls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Attach, remove a callback which adapts the model to use DistributedDL to train in distributed data parallel mode." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def to_distributed(self: Learner, cuda_id):\n", " self.add_cb(DistributedTrainer(cuda_id))\n", " if rank_distrib() > 0: self.remove_cb(ProgressCallback)\n", " return self" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def detach_distributed(self: Learner):\n", " if num_distrib() <=1: return self\n", " self.remove_cb(DistributedTrainer)\n", " if rank_distrib() > 0 and not hasattr(self, 'progress'): self.add_cb(ProgressCallback())\n", " return self" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "@contextmanager\n", "def distrib_ctx(self: Learner, cuda_id=None):\n", " \"A context manager to adapt a learner to train in distributed data parallel mode.\"\n", " # Figure out the GPU to use from rank. Create a dpg if none exists yet.\n", " if cuda_id is None: cuda_id = rank_distrib()\n", " if not torch.distributed.is_initialized():\n", " setup_distrib(cuda_id)\n", " cleanup_dpg = torch.distributed.is_initialized()\n", " else: cleanup_dpg = False\n", " # Adapt self to DistributedDataParallel, yield, and cleanup afterwards.\n", " try:\n", " if num_distrib() > 1: self.to_distributed(cuda_id)\n", " yield self\n", " finally:\n", " self.detach_distributed()\n", " if cleanup_dpg: teardown_distrib()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `distrib_ctx` context manager\n", "\n", "**`distrib_ctx(cuda_id)`** prepares a learner to train in distributed data parallel mode. 
It assumes these [environment variables](https://pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods) have all been set up properly, for example by launching the script with [`python -m fastai2.launch`](https://github.com/fastai/fastai2/blob/master/fastai2/launch.py).\n", "\n", "#### Typical usage:\n", "```\n", "with learn.distrib_ctx(): learn.fit(.....)\n", "```\n", "\n", "It attaches a `DistributedTrainer` callback and `DistributedDL` data loader to the learner, then executes `learn.fit(.....)`. Upon exiting the context, it removes the `DistributedTrainer` and `DistributedDL`, and destroys any locally created distributed process group. The process is still attached to the GPU, though.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def rank0_first(func):\n", "    \"Execute `func` in the Rank-0 process first, then in other ranks in parallel.\"\n", "    dummy_l = Learner(DataLoaders(device='cpu'), nn.Linear(1,1), loss_func=lambda: 0)\n", "    with dummy_l.distrib_ctx():\n", "        if rank_distrib() == 0: res = func()\n", "        distrib_barrier()\n", "        if rank_distrib() != 0: res = func()\n", "    return res" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**`rank0_first(f)`** calls `f()` in the rank-0 process first, then in parallel on the rest, in distributed training mode. In single-process, non-distributed training mode, `f()` is called only once, as expected.\n", "\n", "One application of `rank0_first()` is to make fresh downloads via `untar_data()` safe in distributed training scripts launched by `python -m fastai2.launch