in trl/trainer/bco_trainer.py [0:0]
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
    """Compute reference-model log probabilities for one padded BCO batch.

    The reference forward pass runs with autograd disabled
    (``torch.no_grad``). When ``self.ref_model`` is ``None`` there is no
    separate reference network, so ``self.model`` is called inside the
    context manager returned by ``self.null_ref_context()``; otherwise
    ``self.ref_model`` is called directly.

    Args:
        padded_batch: Collated batch.
            - Encoder-decoder models (``self.is_encoder_decoder``) read
              ``prompt_input_ids``, ``prompt_attention_mask`` and
              ``completion_labels``, plus the optional
              ``completion_decoder_input_ids``.
            - Other models read ``completion_input_ids`` and
              ``completion_attention_mask``.
            - ``completion_labels`` is always read when scoring the logits.

    Returns:
        The value produced by ``self.get_batch_logps`` for the reference
        logits, called with ``average_log_prob=False``.
        NOTE(review): the ``-> dict`` annotation may not match what
        ``get_batch_logps`` actually returns -- confirm before tightening it.
    """

    def _logits_from(network):
        # Feed the input layout that matches the model architecture and
        # return only the logits from the forward output.
        if self.is_encoder_decoder:
            forward_output = network(
                padded_batch["prompt_input_ids"],
                attention_mask=padded_batch["prompt_attention_mask"],
                # ``.get`` yields None when the batch carries no decoder inputs;
                # presumably the model then derives them from ``labels`` -- confirm.
                decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                labels=padded_batch["completion_labels"],
            )
        else:
            forward_output = network(
                padded_batch["completion_input_ids"],
                attention_mask=padded_batch["completion_attention_mask"],
            )
        return forward_output.logits

    with torch.no_grad():
        if self.ref_model is not None:
            reference_logits = _logits_from(self.ref_model)
        else:
            # No dedicated reference model: reuse the policy model under the
            # trainer-provided null-reference context.
            with self.null_ref_context():
                reference_logits = _logits_from(self.model)

    return self.get_batch_logps(
        reference_logits,
        padded_batch["completion_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )