in seamseg/data/sampler.py [0:0]
def _generate_batches(self):
g = torch.Generator()
g.manual_seed(self._epoch)
# Shuffle the two sets separately
self.img_sets[0] = [self.img_sets[0][i] for i in torch.randperm(len(self.img_sets[0]), generator=g)]
self.img_sets[1] = [self.img_sets[1][i] for i in torch.randperm(len(self.img_sets[1]), generator=g)]
batches = []
leftover = []
for img_set in self.img_sets:
batch = []
for img in img_set:
batch.append(img)
if len(batch) == self.batch_size:
batches.append(batch)
batch = []
leftover += batch
if not self.drop_last:
batch = []
for img in leftover:
batch.append(img)
if len(batch) == self.batch_size:
batches.append(batch)
batch = []
if len(batch) != 0:
batches.append(batch)
return batches