def concatenated_inputs()

in src/hyperpod_nemo_adapter/utils/dpo_utils.py [0:0]


def concatenated_inputs(batch: dict, padding_value: int) -> dict:
    """
    Build a single stacked batch out of the `chosen` and `rejected` examples in `batch`.

    The prompt tensors are repeated twice along dim 0, because both completions share the
    same prompt. The chosen and rejected completions (token ids and attention masks) are
    brought to a common length with `pad_to_length`. They are then concatenated along dim 0:
    chosen rows first, rejected rows second.

    Args:
        batch: mapping with keys "prompt_input_ids", "prompt_attention_mask",
            "chosen_input_ids", "chosen_attention_mask", "rejected_input_ids" and
            "rejected_attention_mask". The completion length is taken from `.shape[1]`
            of the chosen/rejected ids (presumably (batch, seq_len) tensors -- confirm
            against the collator that produces `batch`).
        padding_value: fill value used when padding the completion token ids. The
            completion attention masks are always padded with 0.

    Returns:
        dict with keys "prompt_input_ids", "prompt_attention_mask",
        "completion_input_ids" and "completion_attention_mask".
    """

    def doubled(key):
        # Same prompt serves both the chosen and the rejected half.
        tensor = batch[key]
        return torch.cat([tensor, tensor], dim=0)

    stacked = {
        "prompt_input_ids": doubled("prompt_input_ids"),
        "prompt_attention_mask": doubled("prompt_attention_mask"),
    }

    # Common completion length so chosen and rejected rows can be stacked together.
    target_len = max(
        batch["chosen_input_ids"].shape[1],
        batch["rejected_input_ids"].shape[1],
    )

    # Token ids are padded with the caller's padding value; masks are padded with 0.
    for out_key, chosen_key, rejected_key, fill in (
        ("completion_input_ids", "chosen_input_ids", "rejected_input_ids", padding_value),
        ("completion_attention_mask", "chosen_attention_mask", "rejected_attention_mask", 0),
    ):
        stacked[out_key] = torch.cat(
            [
                pad_to_length(batch[chosen_key], target_len, pad_value=fill),
                pad_to_length(batch[rejected_key], target_len, pad_value=fill),
            ],
            dim=0,
        )

    return stacked