in dpr/utils/data_utils.py [0:0]
def iterate_ds_data(self, epoch: int = 0) -> Iterator[Tuple[List, int]]:
    """Yield (batch, dataset_index) pairs interleaved across all underlying dataset iterators.

    Builds a per-epoch schedule containing one entry per batch each dataset is
    expected to contribute (``self.max_its_pr_ds``), optionally shuffles that
    schedule with a seed derived from ``self.shuffle_seed + epoch`` so that a
    stopped/failed run can resume with the exact same order, then pulls batches
    from the corresponding per-dataset sub-iterators in schedule order.

    :param epoch: current epoch number; folded into the shuffle seed so the
        per-epoch schedule is reproducible
    :return: iterator over ``(batch, source_dataset_index)`` tuples
    """
    logger.info("rank=%d; Iteration start", self.rank)
    logger.info(
        "rank=%d; Multi set iteration: iteration ptr per set: %s",
        self.rank,
        [it.get_iteration() for it in self.iterables],
    )

    schedule = []  # dataset index repeated once per batch that dataset contributes
    ds_iterators = []
    for ds_idx, batches_to_take in enumerate(self.max_its_pr_ds):
        logger.info(
            "rank=%d; Multi set iteration: source %d, batches to be taken: %s",
            self.rank,
            ds_idx,
            batches_to_take,
        )
        schedule.extend([ds_idx] * batches_to_take)
        ds_iterators.append(self.iterables[ds_idx].iterate_ds_sampled_data(batches_to_take, epoch=epoch))

    if self.shuffle:
        # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration
        rng = random.Random(self.shuffle_seed + epoch)
        rng.shuffle(schedule)

    logger.info("rank=%d; data_src_indices len=%d", self.rank, len(schedule))
    for ds_idx in schedule:
        batch = next(ds_iterators[ds_idx], None)
        if batch is None:
            # a source ran dry earlier than the schedule expected; skip the slot
            logger.warning("rank=%d; Next item in the source %s is None", self.rank, ds_idx)
        else:
            self.iteration += 1
            yield (batch, ds_idx)

    logger.info("rank=%d; last iteration %d", self.rank, self.iteration)
    logger.info(
        "rank=%d; Multi set iteration finished: iteration per set: %s",
        self.rank,
        [it.iteration for it in self.iterables],
    )

    # TODO: clear iterators in some non-hacky way
    # poke each sub-iterator one more time so exhausted generators run their wrap-up code
    for it in ds_iterators:
        next(it, None)

    for it in self.iterables:
        it.iteration = 0
    logger.info(
        "rank=%d; Multi set iteration finished after next: iteration per set: %s",
        self.rank,
        [it.iteration for it in self.iterables],
    )
    # reset the iteration status
    self.iteration = 0