import torch
class Sampler(object):
    """Base class for all Samplers.

    Every Sampler subclass has to provide an ``__iter__`` method, providing a
    way to iterate over indices of dataset elements, and a ``__len__`` method
    that returns the length of the returned iterators.

    Arguments:
        data_source: accepted for signature compatibility with subclasses;
            the base class does not use or store it.
    """

    def __init__(self, data_source):
        # Nothing to set up in the base class; subclasses keep what they need.
        pass

    def __iter__(self):
        # Must be overridden by every concrete sampler.
        raise NotImplementedError

    def __len__(self):
        # Must be overridden by every concrete sampler.
        raise NotImplementedError
class SequentialSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        # Only the size is kept; it is read once, at construction time.
        self.num_samples = len(data_source)

    def __iter__(self):
        """Return an iterator over ``0 .. num_samples - 1`` in ascending order."""
        return iter(range(self.num_samples))

    def __len__(self):
        """Return the number of indices produced by one pass of ``__iter__``."""
        return self.num_samples
class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        # Only the size is kept; it is read once, at construction time.
        self.num_samples = len(data_source)

    def __iter__(self):
        """Return an iterator over a fresh random permutation of the indices.

        A new permutation is drawn on every call.
        """
        # tolist() converts the permutation to plain Python ints, so callers
        # receive ordinary indices rather than tensor elements.
        return iter(torch.randperm(self.num_samples).long().tolist())

    def __len__(self):
        """Return the number of indices produced by one pass of ``__iter__``."""
        return self.num_samples
class SubsetRandomSampler(Sampler):
    """Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (list): a list of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        """Return an iterator over ``indices`` in a freshly shuffled order.

        A new permutation is drawn on every call.
        """
        # Convert the permutation to plain Python ints so each position is
        # an ordinary integer when used to index ``self.indices``.
        order = torch.randperm(len(self.indices)).tolist()
        return (self.indices[position] for position in order)

    def __len__(self):
        """Return the number of indices in the subset."""
        return len(self.indices)
class WeightedRandomSampler(Sampler):
    """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

    Arguments:
        weights (list): a list of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): passed through to ``torch.multinomial`` as its
            ``replacement`` argument; defaults to ``True``.
    """

    def __init__(self, weights, num_samples, replacement=True):
        # Weights are stored as a double-precision tensor for torch.multinomial.
        self.weights = torch.DoubleTensor(weights)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        """Return an iterator over ``num_samples`` freshly drawn indices."""
        drawn = torch.multinomial(self.weights, self.num_samples, self.replacement)
        # tolist() hands back plain Python ints instead of tensor elements.
        return iter(drawn.tolist())

    def __len__(self):
        """Return the number of indices produced by one pass of ``__iter__``."""
        return self.num_samples