in src/hyperpod_nemo_adapter/utils/dpo_utils.py [0:0]
import torch

# NOTE: `flush_left` and `selective_log_softmax` used below are assumed to be defined or
# imported elsewhere in this module; they mirror the TRL helpers of the same names.


def concatenated_forward(model: torch.nn.Module, batch: dict, max_length: int = 2048, *a, **kw) -> dict:
    """
    Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
    Doing a single forward pass on the concatenated inputs, rather than two separate passes, is faster for FSDP.
    """
    num_examples = batch["prompt_input_ids"].shape[0] // 2

    prompt_input_ids = batch["prompt_input_ids"]
    prompt_attention_mask = batch["prompt_attention_mask"]
    completion_input_ids = batch["completion_input_ids"]
    completion_attention_mask = batch["completion_attention_mask"]

    # Concatenate the prompt and completion inputs
    input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
    attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
    # Mask the prompt but not the completion for the loss
    loss_mask = torch.cat((torch.zeros_like(prompt_attention_mask), completion_attention_mask), dim=1)
    # Flush left to reduce the memory usage
    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
    # Truncate to `max_length`, keeping the end of the sequence
    input_ids = input_ids[:, -max_length:]
    attention_mask = attention_mask[:, -max_length:]
    loss_mask = loss_mask[:, -max_length:]
    # Forward pass on the concatenated batch
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    # Shift the labels left by one so that the logits at position i predict the token at position i + 1
    labels = torch.roll(input_ids, shifts=-1, dims=1)
    loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()
    # Ensure logits and labels align (truncate logits if needed)
    if logits.shape[:2] != labels.shape[:2]:
        # For LLaVA-style models, the returned logits include the image tokens (placed before the text tokens)
        seq_len = labels.shape[1]
        logits = logits[:, -seq_len:]
    # Compute the log probabilities of the labels
    labels[~loss_mask] = 0  # dummy token
    per_token_logps = selective_log_softmax(logits, labels)
    per_token_logps[~loss_mask] = 0  # dummy token
    per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)  # roll back so log-probs align with their token positions
    all_logps = per_token_logps.sum(-1)  # the zeroed dummy tokens do not contribute to the sum
    # The first half of the batch holds the chosen completions, the second half the rejected ones
    output = {
        "chosen_logps": all_logps[:num_examples],
        "rejected_logps": all_logps[num_examples:],
    }
    return output
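
For context, the following is a minimal, hypothetical sketch of how `concatenated_forward` can be driven and how its outputs typically feed a standard DPO loss. The tiny randomly initialized GPT-2 models, the batch contents, and the `beta` value are illustrative assumptions, not part of this module; only the batch keys and the chosen-first/rejected-second row layout are taken from the function above.

# Hypothetical usage sketch (not part of dpo_utils.py)
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel

from hyperpod_nemo_adapter.utils.dpo_utils import concatenated_forward

# Tiny randomly initialized causal LMs so the example runs without downloading weights
config = GPT2Config(vocab_size=128, n_positions=64, n_embd=32, n_layer=2, n_head=2)
policy_model = GPT2LMHeadModel(config)
reference_model = GPT2LMHeadModel(config)  # frozen reference for the DPO loss

# Two preference pairs -> 4 rows, ordered [chosen, chosen, rejected, rejected]
batch = {
    "prompt_input_ids": torch.randint(0, 128, (4, 6)),
    "prompt_attention_mask": torch.ones(4, 6, dtype=torch.long),
    "completion_input_ids": torch.randint(0, 128, (4, 5)),
    "completion_attention_mask": torch.ones(4, 5, dtype=torch.long),
}

policy_out = concatenated_forward(policy_model, batch, max_length=32)
with torch.no_grad():
    ref_out = concatenated_forward(reference_model, batch, max_length=32)

# Standard DPO loss: -log sigmoid(beta * (policy log-ratio - reference log-ratio))
beta = 0.1
policy_logratios = policy_out["chosen_logps"] - policy_out["rejected_logps"]
ref_logratios = ref_out["chosen_logps"] - ref_out["rejected_logps"]
loss = -F.logsigmoid(beta * (policy_logratios - ref_logratios)).mean()
loss.backward()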