def _compute_loss_liger()

in trl/trainer/kto_trainer.py [0:0]


    def _compute_loss_liger(self, model, batch):
        """
        Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss.

        The policy and reference models are run up to (but excluding) their LM heads. Their final hidden states and
        LM-head weights are then handed to `self.kto_loss_fn`, which applies the output projection itself.

        Args:
            model:
                The policy model used for generating log probabilities and outputs. It could be an encoder-decoder
                model or a regular language model.
            batch: A dictionary containing the input data and labels for the batch.

        Returns:
            A dictionary containing the following keys:
                - "loss": The computed KTO loss for the batch.
                - "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model.
                - "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model.
                - "chosen_logps_sum": Summed log probabilities of the chosen responses from the policy model.
                - "rejected_logps_sum": Summed log probabilities of the rejected responses from the policy model.
                - "chosen_rewards_sum": Summed rewards for the chosen responses.
                - "rejected_rewards_sum": Summed rewards for the rejected responses.
                - "kl": The KL divergence between the policy and reference models (detached). This is a zero
                  tensor of shape (1,) when `self.calculate_KL` is False.

            If auxiliary loss is enabled, the dictionary will also include:
                - "aux_loss": The auxiliary loss from the policy model outputs.
        """
        device = self.accelerator.device

        # The KL forward passes are only needed when the KL term is actually used.
        if self.calculate_KL:
            policy_KL_logps = self._compute_kl_logps(model, batch)
            reference_KL_logps = self._compute_kl_logps(self.ref_model, batch)
            kl = (policy_KL_logps - reference_KL_logps).mean().detach()
            kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
        else:
            kl = torch.zeros(1, device=device)

        outputs = self._liger_last_hidden_state(model, batch)
        # The reference model only provides fixed targets for the loss; no gradients must flow through it.
        with torch.no_grad():
            ref_outputs = self._liger_last_hidden_state(self.ref_model, batch)

        lm_head = model.get_output_embeddings()
        ref_lm_head = self.ref_model.get_output_embeddings()

        if self.is_encoder_decoder:
            # Decoder position t predicts label t, so hidden states and labels are already aligned.
            # NOTE(review): assumes `completion_decoder_input_ids` are the labels shifted right (as produced by
            # `prepare_decoder_input_ids_from_labels`) — confirm against the tokenization code.
            hidden = outputs.last_hidden_state
            ref_hidden = ref_outputs.last_hidden_state
            target = batch["completion_labels"]
        else:
            # Causal LM: position t predicts token t + 1, so drop the last hidden state and the first label.
            hidden = outputs.last_hidden_state[:, :-1]
            ref_hidden = ref_outputs.last_hidden_state[:, :-1]
            target = batch["completion_labels"][:, 1:]

        preference_labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device)

        (
            loss,
            (
                chosen_logps_sum,
                rejected_logps_sum,
                chosen_logits_sum,
                rejected_logits_sum,
                chosen_rewards_sum,
                rejected_rewards_sum,
            ),
        ) = self.kto_loss_fn(
            _input=hidden,
            lin_weight=lm_head.weight,
            target=target,
            bias=getattr(lm_head, "bias", None),
            preference_labels=preference_labels,
            ref_input=ref_hidden,
            ref_weight=ref_lm_head.weight,
            # Each head's bias is looked up on that head itself.
            ref_bias=getattr(ref_lm_head, "bias", None),
            kl=kl,
        )

        output = {
            "loss": loss,
            "chosen_logits_sum": chosen_logits_sum,
            "rejected_logits_sum": rejected_logits_sum,
            "chosen_logps_sum": chosen_logps_sum,
            "rejected_logps_sum": rejected_logps_sum,
            "chosen_rewards_sum": chosen_rewards_sum,
            "rejected_rewards_sum": rejected_rewards_sum,
            "kl": kl,
        }
        if self.aux_loss_enabled:
            # NOTE(review): this reads `aux_loss` from the backbone/decoder output, not the full model output.
            # Some MoE backbones may only expose `router_logits` there — confirm for the models in use.
            output["aux_loss"] = outputs.aux_loss

        return output

    def _liger_last_hidden_state(self, model, batch):
        """
        Run `model` on the completion inputs of `batch` up to, but excluding, its LM head.

        Args:
            model: The policy or reference model.
            batch: The batch dictionary passed to `_compute_loss_liger`.

        Returns:
            The output object of the last sub-module called: the decoder for encoder-decoder models, otherwise the
            base model. Callers read `.last_hidden_state` from it (and `.aux_loss` when auxiliary loss is enabled).
        """
        # Only pass kwargs the bare sub-modules understand. `labels` / `decoder_input_ids` are full-model
        # arguments that encoder/decoder stacks (e.g. T5Stack) do not accept.
        extra_kwargs = {"output_router_logits": True} if self.aux_loss_enabled else {}

        if self.is_encoder_decoder:
            encoder_outputs = model.get_encoder()(
                batch["completion_input_ids"],
                attention_mask=batch["completion_attention_mask"],
                return_dict=True,
                **extra_kwargs,
            )
            return model.get_decoder()(
                input_ids=batch.get("completion_decoder_input_ids"),
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                # Keep cross-attention from attending to padded encoder positions.
                encoder_attention_mask=batch["completion_attention_mask"],
                use_cache=False,
                **extra_kwargs,
            )

        # Decoder-only: call the backbone directly so the LM head is skipped.
        if hasattr(model, "get_decoder"):
            base_model = model.get_decoder()
        else:
            base_model = getattr(model, self.args.base_model_attribute_name)
        return base_model(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            use_cache=False,
            **extra_kwargs,
        )