def __next__()

in code/run_eval_prm_trl.py [0:0]


    def __next__(self):
        """Return the next batch of pre-generated step texts.

        Walks ``self.step_mapping`` starting at ``self.current_idx`` and collects
        up to ``self.batch_size`` entries (fewer for the final batch), advancing
        ``self.current_idx`` past every entry consumed.

        Returns:
            A ``(batch_steps, batch_indices)`` tuple where ``batch_steps[i]`` is
            the step text located at ``batch_indices[i] == (dataset_idx, step_idx)``.

        Raises:
            StopIteration: when all ``self.total_steps`` mapped steps have
                already been consumed.
            ValueError: when ``self.batch_size`` is not positive. Without this
                check ``self.current_idx`` would never advance and the iterator
                would yield empty batches forever.
        """
        if self.current_idx >= self.total_steps:
            raise StopIteration
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}")

        batch_indices = []
        batch_steps = []

        while self.current_idx < self.total_steps and len(batch_indices) < self.batch_size:
            dataset_idx, step_idx = self.step_mapping[self.current_idx]
            batch_indices.append((dataset_idx, step_idx))

            # The step texts must already have been generated for this sample.
            # NOTE(review): `get_texts` is read as an attribute, not called, so it
            # presumably is a property / stored sequence of step texts — verify.
            steps = self.data[dataset_idx].get_texts
            batch_steps.append(steps[step_idx])

            # Advance per item so progress is kept even if a later lookup raises.
            self.current_idx += 1

        return batch_steps, batch_indices