in code/run_eval_prm_trl.py [0:0]
def __next__(self):
    """Return the next batch of step texts together with their positions.

    Walks ``self.step_mapping`` starting at ``self.current_idx`` and
    collects entries until either ``self.batch_size`` items have been
    gathered or ``self.total_steps`` is reached. ``self.current_idx`` is
    advanced by one for every entry that was collected.

    Returns:
        A ``(batch_steps, batch_indices)`` pair: the list of step texts
        and the parallel list of ``(dataset_idx, step_idx)`` tuples that
        identify where each text came from.

    Raises:
        StopIteration: when ``self.current_idx`` has already reached
            ``self.total_steps`` on entry.
    """
    if self.current_idx >= self.total_steps:
        raise StopIteration

    collected_texts, collected_positions = [], []

    # The number of gathered texts doubles as the per-batch counter.
    while (
        self.current_idx < self.total_steps
        and len(collected_texts) < self.batch_size
    ):
        sample_no, step_no = self.step_mapping[self.current_idx]
        # Always stored as a fresh tuple, whatever the mapping entry type is.
        collected_positions.append((sample_no, step_no))

        # The step texts must already have been generated at this point.
        # NOTE(review): ``get_texts`` is read as an attribute, not called --
        # presumably a property or precomputed sequence; confirm on the
        # dataset item class.
        sample_texts = self.data[sample_no].get_texts
        collected_texts.append(sample_texts[step_no])

        # Advance only after the entry has been fetched successfully.
        self.current_idx += 1

    return collected_texts, collected_positions