def __init__()

in code/run_eval_prm_trl.py [0:0]


    def __init__(self, data: list[Example], batch_size: int = 32):
        """Store the dataset and build a flat index over all of its steps.

        Args:
            data: Examples to index. Each one must expose a sized ``steps``
                attribute; only its length is consulted here.
            batch_size: Saved on the instance as ``self.batch_size``; the
                constructor itself does not otherwise use it.
        """
        self.data = data
        self.batch_size = batch_size
        # Starts at zero; presumably a cursor into step_mapping -- confirm
        # against the methods that advance it.
        self.current_idx = 0

        # Flat lookup table with one (dataset_idx, step_idx) pair per step,
        # ordered by example first and by step position within the example.
        self.step_mapping = [
            (example_idx, step_pos)
            for example_idx, example in enumerate(data)
            for step_pos in range(len(example.steps))
        ]

        self.total_steps = len(self.step_mapping)