in data/generate_dpo.py [0:0]
def process(self, dataset: StepInput) -> "GeneratorStepOutput":
    """Yield, batch by batch, conversations truncated at the target rating.

    ``self.target_column`` selects the rating to look for: ``"chosen"`` maps
    to ``1`` and ``"rejected"`` maps to ``-1``. The dataset is walked in
    slices of ``self.batch_size`` rows. For each row, ``row["conversation"]``
    is scanned in order and cut right after the first message whose
    ``"rating"`` is an ``int`` equal to the target rating. Rows with no such
    message are dropped.

    Args:
        dataset: Sliceable, sized collection of rows; each row is indexed
            with ``"conversation"`` and each message with ``"rating"``.

    Yields:
        One list per batch containing ``{"conversation": prefix}`` dicts.
        A batch in which no row matches is still yielded, as an empty list.

    Raises:
        KeyError: If ``self.target_column`` is neither ``"chosen"`` nor
            ``"rejected"``, or a row/message lacks the expected key.
    """
    # NOTE(review): the return annotation is "GeneratorStepOutput" but a bare
    # list is yielded — confirm this matches the framework's expected shape.
    column_rating_map = {
        "chosen": 1,
        "rejected": -1,
    }
    target_rating = column_rating_map[self.target_column]
    for batch_start in range(0, len(dataset), self.batch_size):
        batch = dataset[batch_start : batch_start + self.batch_size]
        filtered_batch = []
        # Each row is visited exactly once; a second, nested loop over the
        # same batch would emit every match len(batch) times.
        for row in batch:
            messages = row["conversation"]
            for idx, message in enumerate(messages, 1):
                rating = message["rating"]
                # Non-int ratings (e.g. None for unrated turns) never match.
                if isinstance(rating, int) and rating == target_rating:
                    # idx is 1-based, so the slice includes the matched message.
                    filtered_batch.append({"conversation": messages[:idx]})
                    break
        yield filtered_batch