def __iter__()

in chatlearn/data/sampler.py [0:0]


    def __iter__(self):
        """Yield per-episode lists of (dataset_idx, sample_idx, position) tuples
        for this data-parallel rank.

        Eval mode walks every dataset once, in order, slicing each episode by
        this rank's fixed share.  Train mode loops forever: each episode draws
        ``sample_per_episode // num_inference_per_prompt`` prompts, expands each
        ``num_inference_per_prompt`` times via ``self.repeat``, then takes this
        rank's slice according to ``self.drop_last``.
        """
        # Split sample_per_episode across DP ranks; the first `remainder`
        # ranks receive one extra sample so the shares sum exactly.
        # NOTE: self.remainder is kept as an attribute (not a local) in case
        # other code reads it after iteration starts.
        self.remainder = self.sample_per_episode % self.data_parallel_size
        base_size = self.sample_per_episode // self.data_parallel_size
        batch_size_list = [base_size + 1] * self.remainder + \
            [base_size] * (self.data_parallel_size - self.remainder)
        start = sum(batch_size_list[:self.data_parallel_rank])
        end = start + batch_size_list[self.data_parallel_rank]

        if self.is_eval:
            # `<=` permits equality, so the requirement is "at least", not
            # "larger than" (message fixed to match the condition).
            assert self.sample_per_episode <= sum(self.dataset_sizes), \
                "eval dataset size must be at least sample_per_episode"
            # Flatten all datasets into (dataset_idx, local_idx, global_idx).
            idxes = []
            for dataset_idx, dataset_size in enumerate(self.dataset_sizes):
                offset = len(idxes)
                idxes.extend([(dataset_idx, j, offset + j) for j in range(dataset_size)])
            for i in range(0, len(idxes), self.sample_per_episode):
                episode_samples = idxes[i:i + self.sample_per_episode]
                # NOTE(review): a final partial episode is still sliced with
                # the full-episode [start:end] bounds, so some ranks may get an
                # uneven or empty share — confirm callers tolerate this.
                yield episode_samples[start:end]
            return

        num_samples = self.sample_per_episode // self.num_inference_per_prompt

        def _draw_and_repeat(count):
            # Draw `count` prompts from the (single) sampler and expand each
            # num_inference_per_prompt times; shared by drop/retain modes.
            batch = self.samplers[0].get_next(count)
            batch_idxes = [(0, batch[i], i) for i in range(len(batch))]
            return self.repeat(batch_idxes, interleave=not self.data_rerank)

        while True:
            if self.drop_last == "drop":
                # Fixed per-rank share; trailing samples beyond `end` are dropped.
                batch_idxes = _draw_and_repeat(num_samples)[start:end]
            elif self.drop_last == "retain":
                # Re-balance whatever the sampler actually returned (possibly a
                # short batch) across all ranks, extras going to the low ranks.
                batch_idxes = _draw_and_repeat(num_samples)
                share, extra = divmod(len(batch_idxes), self.data_parallel_size)
                sizes = [share + 1] * extra + \
                    [share] * (self.data_parallel_size - extra)
                lo = sum(sizes[:self.data_parallel_rank])
                batch_idxes = batch_idxes[lo:lo + sizes[self.data_parallel_rank]]
            else:
                # Multi-dataset round-robin weighted by data_ratio; the
                # per-dataset quota (self.dataset_remains) persists across
                # episodes and is refilled when every dataset is exhausted.
                batch_idxes = []
                dataset_id = 0
                while len(batch_idxes) != num_samples:
                    data_num = min(self.dataset_remains[dataset_id],
                                   num_samples - len(batch_idxes))
                    self.dataset_remains[dataset_id] -= data_num
                    batch = self.samplers[dataset_id].get_next(data_num)
                    offset = len(batch_idxes)
                    batch_idxes.extend([(dataset_id, batch[i], offset + i)
                                        for i in range(data_num)])
                    dataset_id = (dataset_id + 1) % self.dataset_num
                    if self.dataset_remains == [0] * self.dataset_num:
                        self.dataset_remains = self.data_ratio[:]
                batch_idxes = self.repeat(batch_idxes, interleave=not self.data_rerank)
                batch_idxes = batch_idxes[start:end]
            yield batch_idxes