in src/datasets/sampler.py [0:0]
def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
    """Assign this replica a contiguous, batch-aligned shard of dataset indices.

    The index range ``[0, len(dataset))`` is padded up to the next multiple
    of ``num_replicas * batch_size`` and split into ``num_replicas``
    contiguous chunks, so every replica receives the same number of indices
    and that number is a whole multiple of ``batch_size``. Padded positions
    wrap around to the beginning of the dataset, which means a few leading
    samples may be seen twice.

    Args:
        dataset: Sized collection; only ``len(dataset)`` is used here. It is
            stored on ``self.dataset``.
        batch_size: Positive per-replica batch size used for alignment.
        num_replicas: Number of participating processes. Defaults to
            ``dist.get_world_size()``.
        rank: Index of this process, in ``[0, num_replicas)``. Defaults to
            ``dist.get_rank()``.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred but
            the distributed package is not available.
        ValueError: If ``batch_size`` or ``num_replicas`` is not positive,
            or ``rank`` lies outside ``[0, num_replicas)``.
    """
    if num_replicas is None or rank is None:
        # Only touch the distributed package when a default is really needed.
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

    # Fail early with a clear message: a non-positive group size breaks the
    # padding arithmetic, and an out-of-range rank would either raise an
    # opaque IndexError or (when negative) silently pick the wrong shard.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if num_replicas <= 0:
        raise ValueError(f"num_replicas must be positive, got {num_replicas}")
    if not 0 <= rank < num_replicas:
        raise ValueError(
            f"Invalid rank {rank}, rank should be in the interval "
            f"[0, {num_replicas - 1}]"
        )

    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank

    num_samples = len(dataset)
    # One "group" is a single batch on every replica; pad to whole groups.
    group_size = num_replicas * batch_size
    remainder = num_samples % group_size
    padding = group_size - remainder if remainder else 0
    total_size = num_samples + padding

    # total_size is a multiple of num_replicas, so the split is even.
    shard = np.array_split(np.arange(total_size), num_replicas)[rank]
    # Map the padded tail back onto valid dataset positions.
    self.indices = shard % num_samples