in trl/trainer/kto_trainer.py [0:0]
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Returns the evaluation [`~torch.utils.data.DataLoader`].

    Overrides the parent trainer's `get_eval_dataloader` to precompute the reference log probabilities when
    `self.precompute_ref_log_probs` is enabled and they have not been computed yet. The values are stored in a
    `reference_logps` column (plus a `reference_KL_logps` column when `self.calculate_KL` is set) before the
    dataset is handed to the parent implementation.

    The augmented dataset is cached on `self.eval_dataset` only when the trainer's own evaluation dataset is
    being used. A dataset passed explicitly by the caller never replaces `self.eval_dataset` and does not mark
    the trainer's dataset as precomputed.

    Args:
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
            by the `model.forward()` method are automatically removed. It must implement `__len__`.

    Returns:
        The dataloader built by the parent class's `get_eval_dataloader` for the (possibly augmented) dataset.

    Raises:
        ValueError: If neither `eval_dataset` nor `self.eval_dataset` is available.
    """
    if eval_dataset is None and self.eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")

    # Remember whether we are working on the trainer's own dataset: only then is it correct to cache the
    # precomputed columns on `self` (otherwise a one-off override would silently replace `self.eval_dataset`).
    using_trainer_dataset = eval_dataset is None or eval_dataset is self.eval_dataset
    if eval_dataset is None:
        eval_dataset = self.eval_dataset

    if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
        dataloader_params = {
            "batch_size": self.args.per_device_eval_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            # Keep the dataset order so the computed values line up row-by-row with `add_column` below.
            "shuffle": False,
        }

        # prepare dataloader
        data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

        reference_completion_logps = []
        reference_KL_logps = []

        for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
            reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)

            # Collect the per-process results and move them to CPU before accumulating them.
            reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
            reference_completion_logps.append(reference_completion_logp.cpu())

            if self.calculate_KL:
                reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
                reference_KL_logps.append(reference_KL_logp.cpu())

        # Cast to float32 before converting to NumPy and attach the values as new dataset columns.
        eval_dataset = eval_dataset.add_column(
            name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
        )
        if self.calculate_KL:
            eval_dataset = eval_dataset.add_column(
                name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
            )

        # Save the calculated reference_logps (and reference_KL_logps) for subsequent runs, but only when they
        # belong to the trainer's own evaluation dataset.
        if using_trainer_dataset:
            self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

    return super().get_eval_dataloader(eval_dataset=eval_dataset)