in src/hyperpod_nemo_adapter/utils/dpo_utils.py [0:0]
def concatenated_inputs(batch: dict, padding_value: int) -> dict:
    """Stack the chosen and rejected halves of a preference batch along dim 0.

    Both halves share the same prompt, so the prompt ids and prompt mask are
    simply repeated. The chosen and rejected completions may have different
    lengths (read from ``.shape[1]``). Each completion tensor is therefore
    passed through ``pad_to_length`` to the longer of the two lengths before
    stacking.

    Args:
        batch: Mapping that must provide ``prompt_input_ids``,
            ``prompt_attention_mask``, ``chosen_input_ids``,
            ``rejected_input_ids``, ``chosen_attention_mask`` and
            ``rejected_attention_mask``. The completion id tensors are
            indexed with ``.shape[1]``, so they need at least two dimensions.
        padding_value: Fill value used when padding the completion input ids.
            The completion attention masks are always padded with 0.

    Returns:
        A new dict with ``prompt_input_ids``, ``prompt_attention_mask``,
        ``completion_input_ids`` and ``completion_attention_mask``. In each
        tensor, the chosen rows come first and the rejected rows follow.
    """
    merged = {}

    # The prompt is identical for chosen and rejected, so just repeat it.
    for prompt_key in ("prompt_input_ids", "prompt_attention_mask"):
        merged[prompt_key] = torch.cat([batch[prompt_key]] * 2, dim=0)

    # Common completion length so the two halves can be concatenated.
    target_len = max(batch[key].shape[1] for key in ("chosen_input_ids", "rejected_input_ids"))

    def _pad_and_stack(chosen_key, rejected_key, fill):
        # Pad each half to target_len, then stack chosen above rejected on dim 0.
        return torch.cat(
            [
                pad_to_length(batch[chosen_key], target_len, pad_value=fill),
                pad_to_length(batch[rejected_key], target_len, pad_value=fill),
            ],
            dim=0,
        )

    merged["completion_input_ids"] = _pad_and_stack("chosen_input_ids", "rejected_input_ids", padding_value)
    # Mask padding is always 0, independent of padding_value.
    merged["completion_attention_mask"] = _pad_and_stack("chosen_attention_mask", "rejected_attention_mask", 0)
    return merged