in ultravox/data/datasets.py [0:0]
def __iter__(self):
    """Yield samples interleaved from ``self._datasets`` in proportion to their budgets.

    ``self._weighted_samples[j]`` is the number of samples dataset ``j`` should
    contribute, and ``self._total_samples`` caps how many steps are taken
    (presumably ``sum(self._weighted_samples)`` -- confirm in ``__init__``).
    At each step the dataset that is least far along, measured as
    ``samples_taken / budget``, is chosen; ties go to the lowest index.

    A dataset whose iterator runs out before its budget is spent is
    restarted, so a budget larger than the dataset's length repeats it.
    Datasets with a non-positive budget are never sampled, and iteration
    ends early once every dataset's budget has been spent.

    Yields:
        Samples drawn from the underlying datasets.

    Warns:
        UserWarning: if a dataset yields nothing even right after being
        restarted; iteration stops at that point. (The message suggests this
        usually means a per-worker shard ended up empty -- unverified here.)
    """
    ds_iters = [iter(ds) for ds in self._datasets]
    ds_pos = [0] * len(ds_iters)
    budgets = self._weighted_samples
    for _ in range(self._total_samples):
        # Only datasets with budget left are eligible; excluding zero budgets
        # also avoids dividing by zero in the fraction below.
        candidates = [
            j
            for j, (pos, budget) in enumerate(zip(ds_pos, budgets))
            if budget > 0 and pos < budget
        ]
        if not candidates:
            # Every budget is spent; there is no well-defined dataset to vend from.
            return
        # Find the iterator that is least far along and vend from it.
        # min() returns the first minimal candidate, so ties go to the lowest index.
        iter_index = min(candidates, key=lambda j: ds_pos[j] / budgets[j])
        try:
            sample = next(ds_iters[iter_index])
        except StopIteration:
            # Ran out before the budget was used up: restart this dataset.
            ds_iters[iter_index] = iter(self._datasets[iter_index])
            try:
                sample = next(ds_iters[iter_index])
            except StopIteration:
                warnings.warn(
                    f"Dataset {iter_index} is empty. num_workers is likely too high. Stopping iteration."
                )
                break
        # Yield outside the try blocks so an exception thrown into the
        # generator here is not mistaken for dataset exhaustion.
        yield sample
        ds_pos[iter_index] += 1