in trl/trainer/kto_trainer.py
def _compute_loss_liger(self, model, batch):
"""
Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss.
Args:
model:
The policy model used for generating log probabilities and outputs. It could be an encoder-decoder
model or a regular language model.
batch: A dictionary containing the input data and labels for the batch.
Returns:
A dictionary containing the following keys:
- "loss": The computed KTO loss for the batch.
- "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model.
- "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model.
- "chosen_logps": Log probabilities of the chosen responses from the policy model.
- "rejected_logps": Log probabilities of the rejected responses from the policy model.
- "chosen_rewards": Rewards for the chosen responses.
- "rejected_rewards": Rewards for the rejected responses.
- "kl": The KL divergence between the policy and reference models (detached).
If auxiliary loss is enabled, the dictionary will also include:
- "aux_loss": The auxiliary loss from the model outputs.
"""
    policy_KL_logps = self._compute_kl_logps(model, batch)
    reference_KL_logps = self._compute_kl_logps(self.ref_model, batch)
    if self.calculate_KL:
        kl = (policy_KL_logps - reference_KL_logps).mean().detach()
        kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
    else:
        kl = torch.zeros(1).to(self.accelerator.device)
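    # Encoder-decoder models need the labels and decoder input ids forwarded to the model;
    # decoder-only models need no extra kwargs. MoE models additionally return router logits
    # so the auxiliary load-balancing loss can be added later.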
    model_kwargs = (
        {
            "labels": batch["completion_labels"],
            "decoder_input_ids": batch.get("completion_decoder_input_ids"),
        }
        if self.is_encoder_decoder
        else {}
    )
    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True
    if self.is_encoder_decoder:
        # 1. Get encoder outputs
        encoder_outputs = model.get_encoder()(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            return_dict=True,
            **model_kwargs,
        )
        # 2. Get decoder outputs
        outputs = model.get_decoder()(
            input_ids=model_kwargs["decoder_input_ids"],
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            use_cache=False,
            **model_kwargs,
        )
        # 3. Get reference encoder outputs
        ref_encoder_outputs = self.ref_model.get_encoder()(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            return_dict=True,
            **model_kwargs,
        )
        # 4. Get reference decoder outputs
        ref_outputs = self.ref_model.get_decoder()(
            input_ids=model_kwargs["decoder_input_ids"],
            encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
            use_cache=False,
            **model_kwargs,
        )
    else:
        # Skip the LM head and get the last hidden state from the base model
        if hasattr(model, "get_decoder"):
            base_model = model.get_decoder()
        else:
            base_model = getattr(model, self.args.base_model_attribute_name)
        outputs = base_model(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        # Reference model
        if hasattr(self.ref_model, "get_decoder"):
            ref_base_model = self.ref_model.get_decoder()
        else:
            ref_base_model = getattr(self.ref_model, self.args.base_model_attribute_name)
        ref_outputs = ref_base_model(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
    lm_head = model.get_output_embeddings()
    ref_lm_head = self.ref_model.get_output_embeddings()
    (
        loss,
        (
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
            chosen_rewards_sum,
            rejected_rewards_sum,
        ),
    ) = self.kto_loss_fn(
        _input=outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
        lin_weight=lm_head.weight,
        target=batch["completion_labels"][:, 1:],
        bias=lm_head.bias if hasattr(lm_head, "bias") else None,
        preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device),
        ref_input=ref_outputs.last_hidden_state[:, :-1]
        if not self.is_encoder_decoder
        else ref_outputs.last_hidden_state,
        ref_weight=ref_lm_head.weight,
        ref_bias=ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None,
        kl=kl,
    )
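    # Package the fused-kernel outputs; the statistics are per-batch sums rather than means.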
    output = {
        "loss": loss,
        "chosen_logits_sum": chosen_logits_sum,
        "rejected_logits_sum": rejected_logits_sum,
        "chosen_logps_sum": chosen_logps_sum,
        "rejected_logps_sum": rejected_logps_sum,
        "chosen_rewards_sum": chosen_rewards_sum,
        "rejected_rewards_sum": rejected_rewards_sum,
        "kl": kl,
    }
    if self.aux_loss_enabled:
        output["aux_loss"] = outputs.aux_loss
    return output