def compute_reference_log_probs()

in trl/trainer/kto_trainer.py [0:0]


    def compute_reference_log_probs(self, padded_batch: dict) -> tuple:
        """Compute reference-model log probabilities for one padded KTO batch.

        The forward passes run under ``torch.no_grad()``. When ``self.ref_model``
        is ``None`` the policy model (``self.model``) is used instead, wrapped in
        ``self.null_ref_context()`` (presumably this makes the policy behave as
        the reference, e.g. by disabling adapters -- confirm against its
        definition).

        Args:
            padded_batch: Mapping holding the padded inputs. For encoder-decoder
                models the keys ``prompt_input_ids``, ``prompt_attention_mask``
                and ``completion_labels`` are read (``completion_decoder_input_ids``
                is optional); otherwise ``completion_input_ids`` and
                ``completion_attention_mask`` are read. ``completion_labels`` is
                always required. When ``self.calculate_KL`` is true the same
                keys prefixed with ``KL_`` are read as well.

        Returns:
            A ``(completion_logps, KL_logps)`` tuple as produced by
            ``self.get_batch_logps(..., average_log_prob=False)``. ``KL_logps``
            is ``None`` when ``self.calculate_KL`` is false.
        """

        def _forward_logits(model, prefix=""):
            """Run ``model`` on the batch entries named with ``prefix`` and return its logits."""
            if self.is_encoder_decoder:
                return model(
                    padded_batch[f"{prefix}prompt_input_ids"],
                    attention_mask=padded_batch[f"{prefix}prompt_attention_mask"],
                    # Optional entry: ``.get`` passes None when the batch has no decoder inputs.
                    decoder_input_ids=padded_batch.get(f"{prefix}completion_decoder_input_ids"),
                    labels=padded_batch[f"{prefix}completion_labels"],
                ).logits
            return model(
                padded_batch[f"{prefix}completion_input_ids"],
                attention_mask=padded_batch[f"{prefix}completion_attention_mask"],
            ).logits

        def _reference_logits(model):
            """Return ``(completion_logits, KL_logits)``; the KL pass is skipped when disabled."""
            completion = _forward_logits(model)
            kl = _forward_logits(model, "KL_") if self.calculate_KL else None
            return completion, kl

        with torch.no_grad():
            if self.ref_model is None:
                # No dedicated reference model: reuse the policy inside the null-reference context.
                with self.null_ref_context():
                    completion_logits, KL_logits = _reference_logits(self.model)
            else:
                completion_logits, KL_logits = _reference_logits(self.ref_model)

        # Both get_batch_logps calls share the same configuration.
        logps_kwargs = {
            "average_log_prob": False,
            "is_encoder_decoder": self.is_encoder_decoder,
            "label_pad_token_id": self.label_pad_token_id,
        }

        completion_logps = self.get_batch_logps(
            completion_logits,
            padded_batch["completion_labels"],
            **logps_kwargs,
        )

        KL_logps = None
        if self.calculate_KL:
            KL_logps = self.get_batch_logps(
                KL_logits,
                padded_batch["KL_completion_labels"],
                **logps_kwargs,
            )

        return completion_logps, KL_logps