def __iter__()

in modules/SwissArmyTransformer/sat/data_utils/configure_data.py [0:0]


    def __iter__(self):
        """Yield samples by randomly alternating between the wrapped datasets.

        Each sample (or, when ``self.batch_from_same_dataset`` is set, each
        run of ``self.batch_size`` consecutive samples) is drawn from one
        dataset chosen at random with probabilities ``self.weights``.  When a
        dataset is exhausted it is dropped and the remaining weights are
        renormalized; iteration ends when every dataset is exhausted.
        """
        iterators = [iter(d) for d in self.datasets]
        # Work on a local copy: the previous implementation deleted entries
        # from self.weights on dataset exhaustion, which corrupted the
        # instance for any later __iter__ call (e.g. the next epoch of a
        # DataLoader) and for concurrent iterators sharing this object.
        weights = list(self.weights)
        # Assume that all datasets iterate follow the seed (mprank, seed, dataloader-worker-id(0 if not used in dataloader)), auto-detect at iter()
        try:
            from sat.mpu import get_data_parallel_rank
            dp_rank = get_data_parallel_rank()
        except Exception:
            # mpu not initialized (e.g. single-process run): behave as rank 0.
            dp_rank = 0
        if self.batch_from_same_dataset:
            # Seed only with self.seed so every data-parallel rank makes the
            # same sequence of dataset choices (keeps batches aligned).
            rng = np.random.default_rng(seed=[self.seed])
        else:
            rng = np.random.default_rng(seed=[dp_rank, self.seed])

        # sampling according to weights from streaming data
        while True:
            index = rng.choice(len(iterators), p=weights)
            # if a dataset raises StopIteration, remove its iterator below
            try:
                if self.batch_from_same_dataset:
                    # we need to make sure the consecutive batch_size samples are from the same iterable dataset.
                    # but accumulate grad does not work.
                    for _ in range(self.batch_size - 1):
                        yield next(iterators[index])
                yield next(iterators[index])
            except StopIteration:
                # Dataset exhausted: drop it and renormalize the remaining
                # weights so they still sum to 1 for rng.choice.
                del iterators[index]
                del weights[index]
                if len(iterators) == 0:
                    break
                s = sum(weights)
                weights = [w / s for w in weights]
                from sat.helpers import print_rank0
                print_rank0(f'AlterDataset: remove a dataset, {len(iterators)} left.')