# Source code for torch.utils.data.sampler

import torch


class Sampler(object):
    """Abstract base shared by all samplers.

    A concrete sampler overrides ``__iter__`` to produce the indices of the
    dataset elements to visit, and ``__len__`` to report how many indices
    such an iterator yields. Both raise ``NotImplementedError`` here.
    """

    def __init__(self, data_source):
        """Accept ``data_source``; the base class records nothing about it."""

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
class SequentialSampler(Sampler):
    """Yields dataset indices in ascending order, the same order on every pass.

    Arguments:
        data_source (Dataset): dataset to sample from; only its length is read.
    """

    def __init__(self, data_source):
        # Only the size is kept; the dataset object itself is not stored.
        count = len(data_source)
        self.num_samples = count

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        ordered = range(self.num_samples)
        return iter(ordered)
class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Each pass yields every index in ``[0, len(data_source))`` exactly once,
    in a fresh order drawn with ``torch.randperm``.

    Arguments:
        data_source (Dataset): dataset to sample from; only its length is read.
    """

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        # randperm already returns int64; tolist() turns the permutation into
        # plain Python ints (matching SequentialSampler) instead of yielding
        # one 0-dim tensor per index.
        return iter(torch.randperm(self.num_samples).tolist())

    def __len__(self):
        return self.num_samples
class SubsetRandomSampler(Sampler):
    """Samples elements randomly from a given list of indices, without replacement.

    Each pass yields every entry of ``indices`` exactly once, in a fresh
    order drawn with ``torch.randperm``.

    Arguments:
        indices (list): a list of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # tolist() makes each position a plain int, so ``self.indices`` is
        # indexed with ints rather than with 0-dim tensors.
        order = torch.randperm(len(self.indices)).tolist()
        return (self.indices[i] for i in order)

    def __len__(self):
        return len(self.indices)
class WeightedRandomSampler(Sampler):
    """Samples elements from ``[0, .., len(weights) - 1]`` with given probabilities (weights).

    Arguments:
        weights (sequence or Tensor): a sequence of weights, not necessarily
            summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True`` (default), samples are drawn with
            replacement; if ``False``, without replacement (see
            ``torch.multinomial`` for the limits this places on ``num_samples``)
    """

    def __init__(self, weights, num_samples, replacement=True):
        # as_tensor accepts Python sequences as well as tensors of any dtype;
        # the legacy torch.DoubleTensor(...) constructor rejects tensors that
        # are not already double.
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        # A fresh draw on every pass; tolist() yields Python ints rather than
        # one 0-dim tensor per sample.
        draws = torch.multinomial(self.weights, self.num_samples, self.replacement)
        return iter(draws.tolist())

    def __len__(self):
        return self.num_samples