in src/hyperpod_nemo_adapter/utils/dpo_utils.py [0:0]
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
    """Collate a batch of DPO preference examples into padded tensors.

    Each example is a dict with ``prompt_input_ids``, ``chosen_input_ids`` and
    ``rejected_input_ids`` token-id sequences, plus optional vision fields
    (``pixel_values``, ``pixel_attention_mask``, ``image_sizes``) and optional
    precomputed reference log-probs (``ref_chosen_logps``/``ref_rejected_logps``).
    Prompts are left-padded; chosen/rejected completions are right-padded with
    ``self.pad_token_id``. Attention masks are all-ones before padding.
    """
    first = examples[0]

    def to_tensors(key: str) -> list["torch.Tensor"]:
        # Per-example sequences under `key`, converted to tensors for padding.
        return [torch.tensor(ex[key]) for ex in examples]

    prompt_ids = to_tensors("prompt_input_ids")
    chosen_ids = to_tensors("chosen_input_ids")
    rejected_ids = to_tensors("rejected_input_ids")
    prompt_mask = [torch.ones_like(t) for t in prompt_ids]
    chosen_mask = [torch.ones_like(t) for t in chosen_ids]
    rejected_mask = [torch.ones_like(t) for t in rejected_ids]

    # Prompt tensors are padded on the left so the completion tokens line up
    # at the boundary; completions use the default (right) padding side.
    batch: dict[str, Any] = {
        "prompt_input_ids": pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left"),
        "prompt_attention_mask": pad(prompt_mask, padding_value=0, padding_side="left"),
        "chosen_input_ids": pad(chosen_ids, padding_value=self.pad_token_id),
        "chosen_attention_mask": pad(chosen_mask, padding_value=0),
        "rejected_input_ids": pad(rejected_ids, padding_value=self.pad_token_id),
        "rejected_attention_mask": pad(rejected_mask, padding_value=0),
    }

    # Optional vision-model inputs — presence is judged from the first example.
    if "pixel_values" in first:
        batch["pixel_values"] = pad(to_tensors("pixel_values"), padding_value=0.0)
    if "pixel_attention_mask" in first:
        batch["pixel_attention_mask"] = pad(to_tensors("pixel_attention_mask"), padding_value=0)
    if "image_sizes" in first:
        batch["image_sizes"] = torch.tensor([ex["image_sizes"] for ex in examples])

    # Optional precomputed reference-model log-probs (scalars, so no padding).
    if "ref_chosen_logps" in first and "ref_rejected_logps" in first:
        batch["ref_chosen_logps"] = torch.tensor([ex["ref_chosen_logps"] for ex in examples])
        batch["ref_rejected_logps"] = torch.tensor([ex["ref_rejected_logps"] for ex in examples])
    return batch