in code/run_eval_prm_trl.py [0:0]
def __init__(self, data: list[Example], batch_size: int = 32):
    """Store the examples and precompute a flat index over all their steps.

    Args:
        data: Examples to index; each item's ``steps`` must support ``len()``.
        batch_size: Stored as ``self.batch_size``; not used within this method.
    """
    self.data = data
    self.batch_size = batch_size
    # NOTE(review): cursor is only initialised here; presumably advanced by
    # other methods of the class — confirm.
    self.current_idx = 0
    # One (dataset_idx, step_idx) pair per step, ordered by example and then
    # by step position within that example.
    self.step_mapping = [
        (example_idx, step_idx)
        for example_idx, example in enumerate(data)
        for step_idx in range(len(example.steps))
    ]
    self.total_steps = len(self.step_mapping)