in trl/trainer/dpo_trainer.py [0:0]
def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
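# Liger-kernel variant of the DPO loss: run the policy (and, unless reference_free, the
# reference) model only up to the final hidden states, then hand the LM head weights,
# hidden states and labels to a fused linear + DPO loss kernel so the full vocabulary
# logits are never materialized for the concatenated chosen/rejected batch.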
unwrapped_model = self.accelerator.unwrap_model(model)
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value)
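# concatenated_inputs stacks the chosen and rejected examples along the batch dimension,
# so every tensor below has 2 * batch_size rows and both halves share a single forward pass.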
model_kwargs = {}
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True
# Add the pixel values and attention masks for vision models
if "pixel_values" in concatenated_batch:
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
if "pixel_attention_mask" in concatenated_batch:
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
if "image_sizes" in concatenated_batch:
model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
completion_attention_mask = concatenated_batch["completion_attention_mask"]
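# Two paths below: encoder-decoder models run the encoder and decoder explicitly on the
# prompt and completion, while decoder-only models concatenate prompt and completion into
# a single sequence and run the base model once.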
if self.is_encoder_decoder:
# 1. Get encoder outputs
encoder_outputs = unwrapped_model.get_encoder()(
concatenated_batch["prompt_input_ids"],
attention_mask=concatenated_batch["prompt_attention_mask"],
return_dict=True,
)
# 2. Prepare decoder inputs
decoder_input_ids = shift_tokens_right(
concatenated_batch["completion_input_ids"],
unwrapped_model.config.decoder_start_token_id,
)
# 3. Get decoder outputs
decoder_outputs = unwrapped_model.get_decoder()(
input_ids=decoder_input_ids,
attention_mask=concatenated_batch["completion_attention_mask"],
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
use_cache=False,
)
hidden_states = decoder_outputs.last_hidden_state
ref_hidden_states = None
if not self.reference_free and self.ref_model is not None:
unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
ref_encoder_outputs = unwrapped_ref_model.get_encoder()(
concatenated_batch["prompt_input_ids"],
attention_mask=concatenated_batch["prompt_attention_mask"],
return_dict=True,
)
ref_decoder_outputs = unwrapped_ref_model.get_decoder()(
input_ids=decoder_input_ids,
attention_mask=concatenated_batch["completion_attention_mask"],
encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
use_cache=False,
)
ref_hidden_states = ref_decoder_outputs.last_hidden_state
elif not self.reference_free:
with self.null_ref_context():
ref_encoder_outputs = unwrapped_model.get_encoder()(
concatenated_batch["prompt_input_ids"],
attention_mask=concatenated_batch["prompt_attention_mask"],
return_dict=True,
)
ref_decoder_outputs = unwrapped_model.get_decoder()(
input_ids=decoder_input_ids,
attention_mask=concatenated_batch["completion_attention_mask"],
encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
use_cache=False,
)
ref_hidden_states = ref_decoder_outputs.last_hidden_state
labels = concatenated_batch["completion_input_ids"]
loss_mask = completion_attention_mask.bool()
else:
# For decoder-only models
input_ids = torch.cat(
(concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1
)
attention_mask = torch.cat(
(concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]),
dim=1,
)
# Mask the prompt but not the completion for the loss
loss_mask = torch.cat(
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
dim=1,
)
# Flush and truncate
if self.max_length is not None and self.max_length < attention_mask.size(1):
if self.truncation_mode == "keep_start":
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
attention_mask = attention_mask[:, : self.max_length]
input_ids = input_ids[:, : self.max_length]
loss_mask = loss_mask[:, : self.max_length]
elif self.truncation_mode == "keep_end":
# Flush right before truncating left, then flush left
# [[0, 0, x, x, x, x], -> [[0, 0, x, x],
# [0, x, x, x, 0, 0]] [0, x, x, x]]
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
input_ids = input_ids[:, -self.max_length :]
attention_mask = attention_mask[:, -self.max_length :]
loss_mask = loss_mask[:, -self.max_length :]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
else:
raise ValueError(
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
"'keep_start']."
)
else:
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
# logits_to_keep optimization: only compute logits for positions that can contribute to the loss
if self.use_logits_to_keep:
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
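# Keep one extra position (+1): hidden states are later shifted by one relative to the
# labels, so the hidden state just before the first completion token is needed to predict it.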
logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1
model_kwargs["logits_to_keep"] = logits_to_keep
model_kwargs["output_hidden_states"] = True
# Padding-free training: drop padding and pack all real tokens of the batch into a single
# row; positional information is carried by position_ids instead of attention_mask
if self.padding_free:
input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
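# Rebuild per-token positions from the attention mask: cumsum(1) counts real tokens per row
# (1-based), the boolean mask drops padding, and subtracting 1 makes the positions 0-based.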
position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
model_kwargs["position_ids"] = position_ids
else:
model_kwargs["attention_mask"] = attention_mask
# Get the base model outputs (before LM head)
if hasattr(unwrapped_model, "get_decoder"):
base_model = unwrapped_model.get_decoder()
else:
base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model)
outputs = base_model(
input_ids,
use_cache=False,
**model_kwargs,
)
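# Drop the last position: with next-token prediction, hidden_states[:, t] is scored against
# the label at t + 1 (see the shift below), so the final hidden state has no label.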
hidden_states = outputs.last_hidden_state[:, :-1]
# Get reference hidden states if needed
ref_hidden_states = None
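# Three cases: an explicit ref_model, an implicit reference obtained by disabling the PEFT
# adapters via null_ref_context(), or no reference at all when reference_free is set.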
if not self.reference_free and self.ref_model is not None:
unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
if hasattr(unwrapped_ref_model, "get_decoder"):
ref_base_model = unwrapped_ref_model.get_decoder()
else:
ref_base_model = getattr(
unwrapped_ref_model, self.args.base_model_attribute_name, unwrapped_ref_model
)
ref_outputs = ref_base_model(
input_ids,
use_cache=False,
**model_kwargs,
)
ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
elif not self.reference_free:
if hasattr(unwrapped_model, "get_decoder"):
ref_base_model = unwrapped_model.get_decoder()
else:
ref_base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model)
with self.null_ref_context():
# Note: attention_mask (or position_ids, in the padding-free case) is already carried in
# model_kwargs; passing it again explicitly would raise a duplicate keyword argument error.
ref_outputs = ref_base_model(
input_ids,
use_cache=False,
**model_kwargs,
)
ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
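# Build the labels: tokens outside the completion (prompt and padding) are replaced with
# label_pad_token_id so the loss ignores them, then the first position is dropped so that
# labels[:, t] is the target for hidden_states[:, t].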
masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id)
labels = masked_input_ids[:, 1:]  # shift labels for causal LM next-token prediction
# Get the LM head
lm_head = unwrapped_model.get_output_embeddings()
# Get reference model weights if needed
ref_weight = None
ref_bias = None
if not self.reference_free:
if self.ref_model is not None:
unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
ref_lm_head = unwrapped_ref_model.get_output_embeddings()
else:
with self.null_ref_context():
ref_lm_head = unwrapped_model.get_output_embeddings()
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
# Compute loss using Liger kernel
loss_output = self.dpo_loss_fn(
lm_head.weight,
hidden_states,
labels,
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
ref_input=ref_hidden_states if not self.reference_free else None,
ref_weight=ref_weight if not self.reference_free else None,
ref_bias=ref_bias if not self.reference_free else None,
)
(
loss,
(chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs),
) = loss_output
output = {
"loss": loss,
"chosen_logps": chosen_logps,
"rejected_logps": rejected_logps,
"mean_chosen_logits": chosen_logits_mean,
"mean_rejected_logits": rejected_logits_mean,
"nll_loss": nll_loss,
"chosen_rewards": aux_outputs[0],
"rejected_rewards": aux_outputs[1],
}
if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss
return output
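# Note: `self.dpo_loss_fn` is assumed here to be Liger's chunked fused linear + DPO loss,
# which applies the LM head projection chunk by chunk while accumulating the DPO loss.
# A rough, hypothetical sketch of its construction (an assumption, not taken from this file):
#
#     from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
#     self.dpo_loss_fn = LigerFusedLinearDPOLoss(
#         ignore_index=self.label_pad_token_id,   # labels equal to this id are masked out
#         beta=self.beta,                         # DPO temperature
#         use_ref_model=not self.reference_free,  # expects ref_input/ref_weight when True
#     )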