def _compute_loss_liger()

in trl/trainer/dpo_trainer.py


    def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
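        """
        Compute the DPO loss using Liger's fused linear + DPO loss kernel.

        The policy (and, unless reference-free, the reference model) is run only up to the
        last hidden states; the LM-head projection and the DPO loss are then computed
        together inside `self.dpo_loss_fn`, avoiding materialization of the full logits.
        """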
        unwrapped_model = self.accelerator.unwrap_model(model)
        concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value)

        model_kwargs = {}
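        # For MoE models, ask the backbone to return router logits so the
        # load-balancing auxiliary loss can be added to the DPO loss below.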
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        # Add the pixel values and attention masks for vision models
        if "pixel_values" in concatenated_batch:
            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
        if "pixel_attention_mask" in concatenated_batch:
            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
        if "image_sizes" in concatenated_batch:
            model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]

        prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
        completion_attention_mask = concatenated_batch["completion_attention_mask"]

        if self.is_encoder_decoder:
            # 1. Get encoder outputs
            encoder_outputs = unwrapped_model.get_encoder()(
                concatenated_batch["prompt_input_ids"],
                attention_mask=concatenated_batch["prompt_attention_mask"],
                return_dict=True,
            )
            # 2. Prepare decoder inputs
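            # shift_tokens_right prepends decoder_start_token_id and shifts the completion
            # right by one, so the decoder predicts each completion token from its predecessors.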
            decoder_input_ids = shift_tokens_right(
                concatenated_batch["completion_input_ids"],
                unwrapped_model.config.decoder_start_token_id,
            )
            # 3. Get decoder outputs
            decoder_outputs = unwrapped_model.get_decoder()(
                input_ids=decoder_input_ids,
                attention_mask=concatenated_batch["completion_attention_mask"],
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
                use_cache=False,
            )
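            # No [:, :-1] slice here: shift_tokens_right already offset the decoder inputs,
            # so the hidden state at position t is aligned with completion token t.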
            hidden_states = decoder_outputs.last_hidden_state

            ref_hidden_states = None
            if not self.reference_free and self.ref_model is not None:
                unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
                ref_encoder_outputs = unwrapped_ref_model.get_encoder()(
                    concatenated_batch["prompt_input_ids"],
                    attention_mask=concatenated_batch["prompt_attention_mask"],
                    return_dict=True,
                )
                ref_decoder_outputs = unwrapped_ref_model.get_decoder()(
                    input_ids=decoder_input_ids,
                    attention_mask=concatenated_batch["completion_attention_mask"],
                    encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
                    encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
                    use_cache=False,
                )
                ref_hidden_states = ref_decoder_outputs.last_hidden_state
            elif not self.reference_free:
                with self.null_ref_context():
                    ref_encoder_outputs = unwrapped_model.get_encoder()(
                        concatenated_batch["prompt_input_ids"],
                        attention_mask=concatenated_batch["prompt_attention_mask"],
                        return_dict=True,
                    )
                    ref_decoder_outputs = unwrapped_model.get_decoder()(
                        input_ids=decoder_input_ids,
                        attention_mask=concatenated_batch["completion_attention_mask"],
                        encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
                        encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
                        use_cache=False,
                    )
                    ref_hidden_states = ref_decoder_outputs.last_hidden_state

            labels = concatenated_batch["completion_input_ids"]
            loss_mask = completion_attention_mask.bool()
        else:
            # For decoder-only models
            input_ids = torch.cat(
                (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1
            )
            attention_mask = torch.cat(
                (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]),
                dim=1,
            )
            # Mask the prompt but not the completion for the loss
            loss_mask = torch.cat(
                (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
                dim=1,
            )

            # Flush and truncate
            if self.max_length is not None and self.max_length < attention_mask.size(1):
                if self.truncation_mode == "keep_start":
                    # Flush left to reduce the memory usage
                    # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
                    #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                    attention_mask = attention_mask[:, : self.max_length]
                    input_ids = input_ids[:, : self.max_length]
                    loss_mask = loss_mask[:, : self.max_length]
                elif self.truncation_mode == "keep_end":
                    # Flush right before truncating left, then flush left
                    # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
                    #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
                    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
                    input_ids = input_ids[:, -self.max_length :]
                    attention_mask = attention_mask[:, -self.max_length :]
                    loss_mask = loss_mask[:, -self.max_length :]
                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                else:
                    raise ValueError(
                        f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
                        "'keep_start']."
                    )
            else:
                # Flush left to reduce the memory usage
                # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
                #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

            # Add logits_to_keep optimization
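            # Only completion tokens enter the loss, so logits for the bulk of the prompt can
            # be skipped; the +1 compensates for the one-position shift between hidden states
            # and labels.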
            if self.use_logits_to_keep:
                first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
                logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1
                model_kwargs["logits_to_keep"] = logits_to_keep

            model_kwargs["output_hidden_states"] = True

            # Add padding-free training support
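            # Pack all non-padded tokens into a single row (batch size 1) and rebuild
            # per-sequence position_ids, so the attention kernel can infer sequence
            # boundaries from position resets instead of an attention mask.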
            if self.padding_free:
                input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
                loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
                position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
                model_kwargs["position_ids"] = position_ids
            else:
                model_kwargs["attention_mask"] = attention_mask

            # Get the base model outputs (before LM head)
            if hasattr(unwrapped_model, "get_decoder"):
                base_model = unwrapped_model.get_decoder()
            else:
                base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model)

            outputs = base_model(
                input_ids,
                use_cache=False,
                **model_kwargs,
            )
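            # Drop the last position: its hidden state has no next-token label.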
            hidden_states = outputs.last_hidden_state[:, :-1]

            # Get reference hidden states if needed
            ref_hidden_states = None
            if not self.reference_free and self.ref_model is not None:
                unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
                if hasattr(unwrapped_ref_model, "get_decoder"):
                    ref_base_model = unwrapped_ref_model.get_decoder()
                else:
                    ref_base_model = getattr(
                        unwrapped_ref_model, self.args.base_model_attribute_name, unwrapped_ref_model
                    )

                ref_outputs = ref_base_model(
                    input_ids,
                    use_cache=False,
                    **model_kwargs,
                )
                ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
            elif not self.reference_free:
                if hasattr(unwrapped_model, "get_decoder"):
                    ref_base_model = unwrapped_model.get_decoder()
                else:
                    ref_base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model)
                with self.null_ref_context():
                    # attention_mask (or position_ids in padding-free mode) already lives in
                    # model_kwargs, so it must not be passed again here.
                    ref_outputs = ref_base_model(
                        input_ids,
                        use_cache=False,
                        **model_kwargs,
                    )
                    ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]

            masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id)
            labels = masked_input_ids[:, 1:]  # Shift labels left by one for causal LM: hidden state t predicts token t + 1

        # Get the LM head
        lm_head = unwrapped_model.get_output_embeddings()

        # Get reference model weights if needed
        ref_weight = None
        ref_bias = None
        if not self.reference_free:
            if self.ref_model is not None:
                unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
                ref_lm_head = unwrapped_ref_model.get_output_embeddings()
            else:
                with self.null_ref_context():
                    ref_lm_head = unwrapped_model.get_output_embeddings()
            ref_weight = ref_lm_head.weight
            ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None

        # Compute loss using Liger kernel
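        # The fused kernel projects hidden states through the LM head and computes the DPO
        # loss chunk by chunk, so the (batch, seq, vocab) logits tensor is never fully built.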
        loss_output = self.dpo_loss_fn(
            lm_head.weight,
            hidden_states,
            labels,
            bias=lm_head.bias if hasattr(lm_head, "bias") else None,
            ref_input=ref_hidden_states if not self.reference_free else None,
            ref_weight=ref_weight if not self.reference_free else None,
            ref_bias=ref_bias if not self.reference_free else None,
        )
        (
            loss,
            (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs),
        ) = loss_output
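        # aux_outputs[0] / aux_outputs[1] are the per-sequence chosen / rejected rewards
        # (beta-scaled log-prob ratios against the reference), unpacked into the dict below.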

        output = {
            "loss": loss,
            "chosen_logps": chosen_logps,
            "rejected_logps": rejected_logps,
            "mean_chosen_logits": chosen_logits_mean,
            "mean_rejected_logits": rejected_logits_mean,
            "nll_loss": nll_loss,
            "chosen_rewards": aux_outputs[0],
            "rejected_rewards": aux_outputs[1],
        }
        if self.aux_loss_enabled:
            output["aux_loss"] = outputs.aux_loss

        return output
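
For context, here is a minimal sketch of how `self.dpo_loss_fn` can be wired up to match the call signature used above. The import path and constructor arguments follow liger-kernel's chunked DPO loss; names and defaults vary between versions, so treat this as an illustration rather than the canonical TRL setup.

    from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

    # Illustrative wiring, mirroring how the loss is invoked above:
    # dpo_loss_fn(lm_head.weight, hidden_states, labels,
    #             bias=..., ref_input=..., ref_weight=..., ref_bias=...)
    self.dpo_loss_fn = LigerFusedLinearDPOLoss(
        ignore_index=self.label_pad_token_id,   # same id used to mask prompt tokens in `labels`
        beta=self.beta,                         # DPO temperature
        use_ref_model=not self.reference_free,  # expects ref_input/ref_weight/ref_bias when True
    )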