in trl/trainer/kto_trainer.py [0:0]
def compute_reference_log_probs(self, padded_batch: dict) -> tuple:
    """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset.

    Returns a tuple `(completion_logps, KL_logps)`; `KL_logps` is `None` when `self.calculate_KL` is `False`.
    """
    with torch.no_grad():
        if self.ref_model is None:
            with self.null_ref_context():
                if self.is_encoder_decoder:
                    completion_logits = self.model(
                        padded_batch["prompt_input_ids"],
                        attention_mask=padded_batch["prompt_attention_mask"],
                        decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                        labels=padded_batch["completion_labels"],
                    ).logits

                    if self.calculate_KL:
                        KL_logits = self.model(
                            padded_batch["KL_prompt_input_ids"],
                            attention_mask=padded_batch["KL_prompt_attention_mask"],
                            decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
                            labels=padded_batch["KL_completion_labels"],
                        ).logits
                else:
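                    # Decoder-only models score the concatenated prompt + completion sequence in one pass.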
                    completion_logits = self.model(
                        padded_batch["completion_input_ids"],
                        attention_mask=padded_batch["completion_attention_mask"],
                    ).logits

                    if self.calculate_KL:
                        KL_logits = self.model(
                            padded_batch["KL_completion_input_ids"],
                            attention_mask=padded_batch["KL_completion_attention_mask"],
                        ).logits
        else:
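            # A dedicated reference model was provided; query it directly.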
            if self.is_encoder_decoder:
                completion_logits = self.ref_model(
                    padded_batch["prompt_input_ids"],
                    attention_mask=padded_batch["prompt_attention_mask"],
                    decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                    labels=padded_batch["completion_labels"],
                ).logits

                if self.calculate_KL:
                    KL_logits = self.ref_model(
                        padded_batch["KL_prompt_input_ids"],
                        attention_mask=padded_batch["KL_prompt_attention_mask"],
                        decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
                        labels=padded_batch["KL_completion_labels"],
                    ).logits
            else:
                completion_logits = self.ref_model(
                    padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
                ).logits

                if self.calculate_KL:
                    KL_logits = self.ref_model(
                        padded_batch["KL_completion_input_ids"],
                        attention_mask=padded_batch["KL_completion_attention_mask"],
                    ).logits
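
    # With `average_log_prob=False`, `get_batch_logps` sums the per-token log probabilities
    # over the completion tokens, skipping positions labelled with `label_pad_token_id`.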
    completion_logps = self.get_batch_logps(
        completion_logits,
        padded_batch["completion_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )
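
    # The KL log-probs come from mismatched prompt/completion pairs and are used by the
    # KTO loss to estimate the reference-point KL term.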
    if self.calculate_KL:
        KL_logps = self.get_batch_logps(
            KL_logits,
            padded_batch["KL_completion_labels"],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
    else:
        KL_logps = None

    return completion_logps, KL_logps