In [None]:
# default_exp distributed

In [None]:
#export
from local.basics import *
from local.callback.progress import ProgressCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler

In [None]:
from local.test import * 

# Distributed and parallel training

> Callbacks and helper functions to train in parallel or use distributed training

## Parallel

Patch `DataParallel` so it works with RNNs, by forwarding `reset` to the wrapped module

In [None]:
#export
@patch
def reset(self: DataParallel):
    "Forward `reset` to the wrapped module when it defines one (needed for wrapped RNNs)."
    inner = self.module
    if hasattr(inner, 'reset'):
        inner.reset()

In [None]:
#export
class ParallelTrainer(Callback):
    "Wrap the learner's model in `DataParallel` over `device_ids` for the duration of `fit`."
    run_after = TrainEvalCallback
    run_before = Recorder

    def __init__(self, device_ids):
        self.device_ids = device_ids

    def begin_fit(self):
        "Replace the learner's model by its `DataParallel` wrapper."
        wrapped = DataParallel(self.learn.model, device_ids=self.device_ids)
        self.learn.model = wrapped

    def after_fit(self):
        "Put the unwrapped model back on the learner."
        self.learn.model = self.learn.model.module

In [None]:
#export
@patch
def to_parallel(self: Learner, device_ids=None):
    "Add a `ParallelTrainer` using `device_ids` and return the learner, so calls can be chained."
    trainer = ParallelTrainer(device_ids)
    self.add_cb(trainer)
    return self

## Distributed

Patch `DistributedDataParallel` so it works with RNNs, by forwarding `reset` to the wrapped module

In [None]:
#export
@patch
def reset(self: DistributedDataParallel):
    "Forward `reset` to the wrapped module when it defines one (needed for wrapped RNNs)."
    inner = self.module
    if hasattr(inner, 'reset'):
        inner.reset()

In [None]:
#export
def setup_distrib(gpu=None):
    "Select GPU `gpu` for this process and init the NCCL process group when `num_distrib()` is above 1; returns `gpu` as an `int` (or `None`)."
    if gpu is None:
        return None
    gpu_idx = int(gpu)
    torch.cuda.set_device(gpu_idx)
    if num_distrib() > 1:
        # Rendezvous information is read from environment variables (`init_method='env://'`).
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return gpu_idx

### DataLoader

We need to change the dataloaders so that each process only gets its own part of the data (otherwise there is no point in using distributed training).

In [None]:
#export
@delegates()
class DistributedDL(TfmdDL):
    "`TfmdDL` that only yields the items assigned to process `rank` out of `world_size` processes."

    def __init__(self, dataset, rank, world_size, **kwargs):
        super().__init__(dataset, **kwargs)
        # Pad the length up to a multiple of `world_size` so every process gets the same number of
        # items; indices past the end of `dataset` are wrapped back to the start in `create_item`.
        # NOTE(review): assumes `self.n` is known; a `None` length would fail on the `%` below.
        if self.n%world_size != 0: self.n += world_size-self.n%world_size
        # `total_n`: padded length shared by all processes; `n`: number of items for this process.
        self.total_n,self.n = self.n,self.n//world_size
        store_attr(self, 'rank,world_size')
        # Default seed for `shuffle_fn` until `set_epoch` is called (e.g. by `DistributedTrainer`).
        # Without it, iterating a shuffled instance before `set_epoch` failed on the missing `epoch`.
        self.epoch = 0

    def get_idxs(self):
        "Indices `0..total_n-1` (or `None`s when not `indexed`) over the padded length shared by all processes."
        idxs = Inf.count if self.indexed else Inf.nones
        # Unknown length: hand back the unbounded iterator unchanged.
        return idxs if self.n is None else list(itertools.islice(idxs, self.total_n))

    def shuffle_fn(self, idxs):
        "Deterministically shuffle on each training process based on epoch."
        # Every process seeds with the same epoch, so all of them see the same permutation
        # and the strided subsampling in `sample` yields disjoint shards.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        return L(idxs)[torch.randperm(self.total_n, generator=g)]

    def sample(self):
        "Indices for this process (every `world_size`-th one starting at `rank`), restricted to the current worker's batches."
        idxs = self.get_idxs()
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        # add extra samples to make it evenly divisible (no-op when `get_idxs` already returned `total_n` items)
        idxs += idxs[:(self.total_n - len(idxs))]
        # subsample: this process takes positions rank, rank+world_size, rank+2*world_size, ...
        idxs = idxs[self.rank:self.total_n:self.world_size]
        # Keep only the batches handled by the current loader worker
        # (`nw`/`offs` presumably set per worker by the base `DataLoader`).
        return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)

    def create_item(self, s):
        "Wrap padding indices (>= `len(dataset)`) back to the start of the dataset, then create the item."
        if s is not None and s >= len(self.dataset): s = s%len(self.dataset)
        return super().create_item(s)

    def set_epoch(self, epoch):
        "Set the shuffle seed; all processes must use the same `epoch` to agree on the permutation."
        self.epoch = epoch

    @classmethod
    def from_dl(cls, dl, rank, world_size, **kwargs):
        "Build a `DistributedDL` with the settings of `dl` (and its `_methods`, except those overridden here); `kwargs` take precedence."
        cur_kwargs = dict(num_workers=dl.fake_l.num_workers, pin_memory=dl.pin_memory, timeout=dl.timeout,
                          bs=dl.bs, shuffle=dl.shuffle, drop_last=dl.drop_last, indexed=dl.indexed)
        cur_kwargs.update({n: getattr(dl, n) for n in cls._methods if n not in "sample shuffle_fn create_item".split()})
        return cls(dl.dataset, rank, world_size, **merge(cur_kwargs, kwargs))

In [None]:
# Without shuffling, rank `r` of 4 sees indices r, r+4, r+8, ... of the length padded from 50 to 52;
# the two padding indices (50, 51) wrap around to 0 and 1. Each rank's 13 items fit in one batch of 16.
dl = TfmdDL(list(range(50)), bs=16, num_workers=2)
for rank in range(4):
    dl1 = DistributedDL.from_dl(dl, rank, 4)
    first_batch = list(dl1)[0]
    test_eq(first_batch, torch.arange(rank, 52, 4)%50)

In [None]:
# With shuffling, every rank seeds with the same epoch and therefore uses the same permutation,
# so together the 4 ranks cover each index exactly once, except the padding items 0 and 1 (seen twice).
dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)
res = []
for rank in range(4):
    dl1 = DistributedDL.from_dl(dl, rank, 4)
    dl1.set_epoch(0)
    res.extend(list(dl1)[0].tolist())
test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))

In [None]:
#export
class DistributedTrainer(Callback):
    "Train with `DistributedDataParallel` on GPU `cuda_id`, giving each process its own shard of the data during `fit`."
    run_after = TrainEvalCallback
    run_before = Recorder

    def __init__(self, cuda_id=0):
        self.cuda_id = cuda_id

    def begin_fit(self):
        "Wrap the model in DDP, swap every dataloader for a `DistributedDL` and silence logging on non-master ranks."
        self.learn.model = DistributedDataParallel(
            self.model, device_ids=[self.cuda_id], output_device=self.cuda_id)
        # Remember the original dataloaders so `after_fit` can put them back.
        self.old_dls = list(self.dbunch.dls)
        self.learn.dbunch.dls = [
            DistributedDL.from_dl(loader, rank_distrib(), num_distrib())
            for loader in self.dbunch.dls
        ]
        if rank_distrib() > 0:
            self.learn.logger = noop

    def begin_epoch(self):
        "Reseed every `DistributedDL` with the current epoch so all processes shuffle identically."
        for loader in self.dbunch.dls:
            loader.set_epoch(self.epoch)

    def after_fit(self):
        "Unwrap the model and restore the original dataloaders."
        self.learn.model, self.learn.dbunch.dls = self.learn.model.module, self.old_dls

In [None]:
#export
@patch
def to_distributed(self: Learner, cuda_id):
    "Add a `DistributedTrainer` for GPU `cuda_id`; on non-master ranks also remove the progress callback. Returns `self`."
    self.add_cb(DistributedTrainer(cuda_id))
    if rank_distrib() > 0:
        # Only rank 0 reports progress. Look the callback up defensively: the learner may have no
        # `progress` attribute (created without a progress callback, or it was already removed),
        # in which case the previous direct `self.progress` access raised `AttributeError`.
        progress = getattr(self, 'progress', None)
        if progress is not None: self.remove_cb(progress)
    return self

## Export -

In [None]:
#hide
from local.notebook.export import notebook2script
# Convert the notebooks to library modules (presumably from their `#export` cells); with
# `all_fs=True` every notebook is processed, as the "Converted ..." output shows.
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_utils.ipynb.
Converted 01b_dispatch.ipynb.
Converted 01c_transform.ipynb.
Converted 02_script.ipynb.
Converted 03_torch_core.ipynb.
Converted 03a_layers.ipynb.
Converted 04_dataloader.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 07_data_block.ipynb.
Converted 08_vision_core.ipynb.
Converted 09_vision_augment.ipynb.
Converted 09a_vision_data.ipynb.
Converted 10_pets_tutorial.ipynb.
Converted 11_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 13a_metrics.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 14a_callback_data.ipynb.
Converted 15_callback_hook.ipynb.
Converted 15a_vision_models_unet.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision_learner.ipy