in dpr/utils/data_utils.py [0:0]
def iterate_ds_data(self, epoch: int = 0) -> Iterator[List]:
    """Yield this shard's batches for one epoch, padded up to ``self.max_iterations``.

    Batches are built by slicing ``self.get_shard_indices(epoch)`` in steps of
    ``self.batch_size`` and looking each index up in ``self.data``. Iteration
    starts at batch number ``self.iteration``, so a non-zero value resumes in
    the middle of the epoch.

    If ``self.strict_batch_size`` is set, a short batch is topped up with
    indices taken from the start of the shard. Once the shard is exhausted,
    the shard's first batch is repeated until ``self.iteration`` reaches
    ``self.max_iterations``.

    Args:
        epoch: Passed through to ``self.get_shard_indices``.

    Yields:
        A list of ``self.data`` items for each batch.

    Side effects:
        ``self.iteration`` is incremented for every yielded batch and reset
        to 0 once the generator has run to completion.
    """
    # ``self.iteration`` is an absolute batch counter within the epoch (it is also
    # used as the start offset below), so the padding loop must compare it with the
    # absolute target. Subtracting the resume offset here would under-pad on resume.
    max_iterations = self.max_iterations
    shard_indices = self.get_shard_indices(epoch)

    for start in range(self.iteration * self.batch_size, len(shard_indices), self.batch_size):
        items_idxs = shard_indices[start : start + self.batch_size]
        if self.strict_batch_size and len(items_idxs) < self.batch_size:
            logger.debug("Extending batch to max size")
            # Top up from the beginning of the shard. The missing count is based on
            # the current index batch, not on the previously yielded items.
            items_idxs.extend(shard_indices[0 : self.batch_size - len(items_idxs)])
        self.iteration += 1
        yield [self.data[idx] for idx in items_idxs]

    # Some shards may be done iterating while the others are at the last batch.
    # Just return the first batch until the target iteration count is reached.
    while self.iteration < max_iterations:
        logger.debug("Fulfilling non complete shard=%s", self.shard_id)
        self.iteration += 1
        items_idxs = shard_indices[0 : self.batch_size]
        yield [self.data[idx] for idx in items_idxs]

    logger.info("Finished iterating, iteration=%s, shard=%s", self.iteration, self.shard_id)
    # reset the iteration status
    self.iteration = 0