def concatenated_forward()

in src/hyperpod_nemo_adapter/utils/dpo_utils.py


import torch

# Assumption: `flush_left` and `selective_log_softmax` are the TRL helpers of the
# same name; adjust the import if this module ships its own implementations.
from trl.trainer.utils import flush_left, selective_log_softmax


def concatenated_forward(model: torch.nn.Module, batch: dict, max_length: int = 2048, *args, **kwargs) -> dict:
    """
    Run the model on a batch in which the chosen and rejected inputs are concatenated
    along the batch dimension. Doing one forward pass instead of two is faster under FSDP.
    Extra positional and keyword arguments are accepted and ignored.
    """
    num_examples = batch["prompt_input_ids"].shape[0] // 2
    prompt_input_ids = batch["prompt_input_ids"]
    prompt_attention_mask = batch["prompt_attention_mask"]
    completion_input_ids = batch["completion_input_ids"]
    completion_attention_mask = batch["completion_attention_mask"]

    # Concatenate the prompt and completion inputs
    input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
    attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
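    # input_ids and attention_mask now have shape (2 * num_examples, prompt_len + completion_len)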
    # Mask the prompt but not the completion for the loss
    loss_mask = torch.cat((torch.zeros_like(prompt_attention_mask), completion_attention_mask), dim=1)
    # Flush left (shift non-padding tokens to the start of each row) to drop shared
    # leading padding and reduce memory usage
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

    # Left-truncate to at most max_length tokens, keeping the end of each sequence
    input_ids = input_ids[:, -max_length:]
    attention_mask = attention_mask[:, -max_length:]
    loss_mask = loss_mask[:, -max_length:]

    # Forward pass
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # (batch_size, seq_len, vocab_size)

    # Shift the labels left by one so the logits at position t are scored against the
    # token at position t + 1; the loss mask is rolled identically to stay aligned
    labels = torch.roll(input_ids, shifts=-1, dims=1)
    loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()

    # Ensure logits and labels align (truncate logits if needed)
    if logits.shape[:2] != labels.shape[:2]:
        # For VLMs such as LLaVA, the returned logits also cover the image tokens
        # (placed before the text tokens), so keep only the trailing text positions
        seq_len = labels.shape[1]
        logits = logits[:, -seq_len:]

    # Compute the per-token log probabilities of the labels:
    # per_token_logps[b, t] = log_softmax(logits[b, t])[labels[b, t]]
    labels[~loss_mask] = 0  # dummy id so the gather inside selective_log_softmax is valid
    per_token_logps = selective_log_softmax(logits, labels)
    per_token_logps[~loss_mask] = 0  # zero out masked positions so they don't affect the sum

    # Roll back by one so the per-token log-probs line up with their input positions;
    # the sum below is unaffected because masked positions were zeroed above
    per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)
    all_logps = per_token_logps.sum(-1)

    # Rows [0:num_examples] are the chosen completions, rows [num_examples:] the rejected
    output = {
        "chosen_logps": all_logps[:num_examples],
        "rejected_logps": all_logps[num_examples:],
    }
    return output
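
A minimal usage sketch (not part of the source): it builds a toy two-row batch whose first row is the chosen completion and second row the rejected one, runs concatenated_forward, and plugs the returned log-probs into the standard DPO loss. The gpt2 checkpoint, the beta value, and the ref_out placeholder are illustrative assumptions.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from hyperpod_nemo_adapter.utils.dpo_utils import concatenated_forward

model = AutoModelForCausalLM.from_pretrained("gpt2")

# One preference pair: identical prompts, different completions.
# Rows [0:num_examples] are chosen, rows [num_examples:] are rejected.
batch = {
    "prompt_input_ids": torch.tensor([[10, 11, 12], [10, 11, 12]]),
    "prompt_attention_mask": torch.ones(2, 3, dtype=torch.long),
    "completion_input_ids": torch.tensor([[20, 21], [30, 31]]),
    "completion_attention_mask": torch.ones(2, 2, dtype=torch.long),
}

with torch.no_grad():
    policy_out = concatenated_forward(model, batch, max_length=2048)

# DPO loss (Rafailov et al., 2023). ref_out should come from a frozen reference
# model via the same call; reusing policy_out here is only a placeholder.
ref_out = policy_out
beta = 0.1
margin = (policy_out["chosen_logps"] - policy_out["rejected_logps"]) - (
    ref_out["chosen_logps"] - ref_out["rejected_logps"]
)
loss = -F.logsigmoid(beta * margin).mean()

With ref_out taken from a real frozen reference model, the loss decreases as the policy assigns relatively more probability mass to the chosen completion than to the rejected one, compared to the reference.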