in src/data_manager.py [0:0]
def __iter__(self):
    """Yield batches of sample indices, epoch after epoch.

    At the start of iteration the subsample index ``self._ssi`` is reset
    (to ``rank * cpb`` when ``unique_cpb`` is set, else 0) and
    ``self._next_perm()`` is called once.

    For every epoch, ``self._get_local_samplers(epoch)`` is called once and
    an iterator is obtained from ``self._subsample_samplers(samplers)``.
    Each batch is made of ``self.batch_size`` draws from that iterator,
    with every draw flattened into the batch list. Presumably one draw is
    one index per subsampled class -- confirm against
    ``_subsample_samplers``.

    Per epoch there are ``num_classes // cpb`` class groups, or
    ``num_classes // (cpb * world_size)`` when ``unique_cpb`` is set. With
    a positive ``batch_size`` this is also the number of batches yielded.
    A fresh subsample is taken after every batch except the epoch's last.
    """
    self._ssi = 0 if not self.unique_cpb else self.rank * self.cpb
    self._next_perm()

    # -- class groups per epoch; ranks split the classes when unique_cpb is set
    if self.unique_cpb:
        class_groups = self.num_classes // (self.cpb * self.world_size)
    else:
        class_groups = self.num_classes // self.cpb
    # -- total draws per epoch: batch_size draws for every class group
    draws_per_epoch = class_groups * self.batch_size
    last_draw = draws_per_epoch - 1

    for epoch in range(self.epochs):
        # -- per-epoch samplers, then the first group of classes to draw from
        samplers = self._get_local_samplers(epoch)
        draws = self._subsample_samplers(samplers)

        pending, filled = [], 0
        for step in range(draws_per_epoch):
            pending.extend(next(draws))
            filled += 1
            if filled != self.batch_size:
                continue
            yield pending
            pending, filled = [], 0
            # -- switch to a new class group unless this was the final draw
            if step != last_draw:
                draws = self._subsample_samplers(samplers)