in modules/SwissArmyTransformer/sat/data_utils/configure_data.py [0:0]
def __iter__(self):
    iterators = [iter(d) for d in self.datasets]
    # Assume every dataset seeds its iteration with (mp_rank, seed, dataloader-worker-id),
    # where the worker id is 0 if no dataloader workers are used; this is auto-detected at iter().
    try:
        from sat.mpu import get_data_parallel_rank
        dp_rank = get_data_parallel_rank()
    except Exception:
        dp_rank = 0
    if self.batch_from_same_dataset:
        rng = np.random.default_rng(seed=[self.seed])
    else:
        rng = np.random.default_rng(seed=[dp_rank, self.seed])
    # Sample from the streaming datasets according to their weights.
    while True:
        index = rng.choice(len(iterators), p=self.weights)
        # If the chosen dataset is exhausted (StopIteration), remove its iterator.
        try:
            if self.batch_from_same_dataset:
                # Ensure that batch_size consecutive samples come from the same iterable dataset.
                # Note: gradient accumulation does not work in this mode.
                for i in range(self.batch_size - 1):
                    yield next(iterators[index])
            yield next(iterators[index])
        except StopIteration:
            del iterators[index]
            del self.weights[index]
            if len(iterators) == 0:
                break
            # Renormalize the remaining weights so they still sum to 1.
            s = sum(self.weights)
            self.weights = [w / s for w in self.weights]
            from sat.helpers import print_rank0
            print_rank0(f'AlterDataset: remove a dataset, {len(iterators)} left.')
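
For context, below is a minimal standalone sketch of the same weighted-mixing pattern, assuming plain Python iterables in place of SAT's streaming datasets and distributed utilities. The WeightedMixer class name and the toy inputs are illustrative only and are not part of configure_data.py.

import numpy as np

class WeightedMixer:
    """Sketch: interleave several iterables, sampling by weight,
    dropping exhausted ones and renormalizing the remaining weights."""
    def __init__(self, datasets, weights, seed=0):
        self.datasets = datasets
        self.weights = list(weights)
        self.seed = seed

    def __iter__(self):
        iterators = [iter(d) for d in self.datasets]
        weights = list(self.weights)
        rng = np.random.default_rng(seed=[self.seed])
        while iterators:
            # Pick one of the remaining iterators according to the current weights.
            index = rng.choice(len(iterators), p=weights)
            try:
                yield next(iterators[index])
            except StopIteration:
                # Drop the exhausted iterator and renormalize the weights that remain.
                del iterators[index]
                del weights[index]
                if iterators:
                    s = sum(weights)
                    weights = [w / s for w in weights]

# Example: mix two finite streams roughly 3:1 until both are exhausted.
mixer = WeightedMixer([range(5), 'abc'], weights=[0.75, 0.25], seed=42)
print(list(mixer))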