in samplers.py [0:0]
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3,
             selected_round: int = 256):
    """Configure a repeated-augmentation sampler over ``dataset``.

    Args:
        dataset: Sized collection to sample from; only ``len(dataset)`` is
            used here (the object itself is stored on ``self.dataset``).
        num_replicas: Number of participating processes. When ``None`` it is
            taken from ``dist.get_world_size()``.
        rank: Index of this process, must lie in ``[0, num_replicas)``. When
            ``None`` it is taken from ``dist.get_rank()``.
        shuffle: Stored on ``self.shuffle``; how it is applied is decided
            outside this method (presumably in ``__iter__``).
        num_repeats: How many times each dataset index is counted when sizing
            ``num_samples``; must be >= 1.
        selected_round: ``num_selected_samples`` is computed from the dataset
            length rounded down to a multiple of this value (default 256,
            identical to the previous hard-coded behaviour). Pass 0 to disable
            rounding and use ``ceil(len(dataset) / num_replicas)`` instead.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred but
            ``dist.is_available()`` is false.
        ValueError: If ``num_repeats < 1``, ``selected_round < 0``, or ``rank``
            is outside ``[0, num_replicas)`` (which also covers
            ``num_replicas <= 0``).
    """
    if num_replicas is None:
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        num_replicas = dist.get_world_size()
    if rank is None:
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        rank = dist.get_rank()
    if num_repeats < 1:
        raise ValueError("num_repeats should be greater than 0")
    # Reject impossible ranks up front; previously an out-of-range rank was
    # accepted silently and num_replicas == 0 only failed as ZeroDivisionError.
    if not 0 <= rank < num_replicas:
        raise ValueError(
            f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
    if selected_round < 0:
        raise ValueError("selected_round should be non-negative")

    self.dataset = dataset
    self.num_replicas = num_replicas
    self.rank = rank
    self.num_repeats = num_repeats
    self.epoch = 0

    dataset_len = len(dataset)
    # Every index is counted num_repeats times and the result is split evenly
    # across replicas, rounding up; total_size is therefore >= the repeated
    # dataset length.
    self.num_samples = math.ceil(dataset_len * num_repeats / num_replicas)
    self.total_size = self.num_samples * num_replicas
    if selected_round:
        # Round the dataset length down to a multiple of selected_round, then
        # split across replicas. Exact integer arithmetic replaces the former
        # float division + floor. NOTE: with the default of 256, a dataset with
        # fewer than 256 items yields num_selected_samples == 0.
        self.num_selected_samples = (
            dataset_len // selected_round * selected_round // num_replicas)
    else:
        self.num_selected_samples = math.ceil(dataset_len / num_replicas)
    self.shuffle = shuffle