in trl/trainer/callbacks.py
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
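    """Generate baseline completions with the reference model and log the initial win rate."""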
    # When training begins, we generate completions once from the reference model; they serve
    # as the fixed baseline for the win-rate evaluation.
    tokenizer = kwargs["processing_class"]
    tokenizer.padding_side = "left"
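    # Decoder-only models must be left-padded for batched generation so that new tokens are
    # appended directly after each prompt rather than after padding tokens.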
    accelerator = self.trainer.accelerator
    # Use the reference model if available, otherwise use the initial model
    model = getattr(self.trainer, "ref_model", None)
    # At this point, there are two cases where `ref_model` is None:
    #   1. The method doesn't require a reference model.
    #   2. The method uses a reference model, but `ref_model` is set to None.
    #      This occurs when using PEFT, where the reference model can be obtained by simply
    #      disabling the model's adapter. In theory, we should disable the adapter here, but
    #      since it's zero-initialized at the start of training, the model behaves identically
    #      with or without the adapter, so there's no need to disable it explicitly at this point.
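    # For illustration, explicit disabling with a PEFT model would look roughly like this
    # (hypothetical sketch; `disable_adapter()` is PEFT's context manager for this purpose):
    #
    #     with self.trainer.model.disable_adapter():
    #         ...  # forward passes in here use the base weights only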
    if model is None:
        model = self.trainer.model_wrapped
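    # Shard the evaluation prompts across processes so that each rank only generates
    # completions for its own slice; the shards are gathered back together further down.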
    with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
        self.ref_completions = _generate_completions(
            prompts,
            model=model,
            tokenizer=tokenizer,
            accelerator=accelerator,
            generation_config=self.generation_config,
            batch_size=args.per_device_eval_batch_size,
        )
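
        # The completions are kept on `self` so that later evaluation hooks can reuse them as
        # the fixed reference side of each judged pair.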
        # Compute initial win rate as a reference point
        completions = list(zip(self.ref_completions, self.ref_completions))
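        # Each pair contains the same reference completion twice, so the judge compares
        # identical texts; any deviation from a ~0.5 win rate exposes judge noise or
        # position bias rather than a real quality difference.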
        if self.use_soft_judge:
            ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
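            # With `return_scores=True` the judge returns a per-pair score interpreted as the
            # probability that the first completion wins; thresholding at 0.5 recovers hard
            # winner indices.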
            winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
            ref_win_probs = gather_object(ref_win_probs)
        else:
            winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
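        # Gather the per-process shards onto every rank so the logging below sees the full
        # evaluation set rather than a single slice.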
        prompts = gather_object(prompts)
        completions = gather_object(completions)
        winner_indices = gather_object(winner_indices)

    # Logging
    if self.trainer.accelerator.is_main_process:
        win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
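        # A winner index of 1 means the second completion won; in later evaluations that slot
        # holds the trained model's output, so this initial value should sit near 0.5.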
        if self.use_soft_judge:
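            # `ref_win_probs` holds P(first completion wins) for each pair, so its complement
            # is the average win probability of the second (future model) slot.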
            avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
            self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
        else:
            self.trainer.log({"eval_win_rate": win_rate})
if "wandb" in args.report_to:
import wandb
if wandb.run is not None:
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})
if "comet_ml" in args.report_to:
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
log_table_to_comet_experiment(
name="win_rate_completions.csv",
table=df,
)
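
Below is a minimal usage sketch of how a callback like this is typically attached to a trainer. It assumes the method above belongs to trl's WinRateCallback and mirrors the wiring shown in the TRL docs (DPOTrainer plus PairRMJudge); treat the exact model, dataset, and argument choices as illustrative assumptions, not guarantees of this excerpt.

from datasets import load_dataset
from trl import DPOConfig, DPOTrainer, PairRMJudge, WinRateCallback

# Any preference dataset accepted by DPOTrainer works here; prompts from the eval split feed
# the `self.eval_dataset["prompt"]` access in on_train_begin above.
dataset = load_dataset("trl-lib/ultrafeedback_binarized")

trainer = DPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # assumed model id; TRL loads it internally
    args=DPOConfig(output_dir="winrate-demo", per_device_eval_batch_size=2),
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"].select(range(64)),
)
trainer.add_callback(WinRateCallback(judge=PairRMJudge(), trainer=trainer))
trainer.train()  # the callback logs the baseline eval_win_rate before the first optimizer step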