in trl/trainer/nash_md_trainer.py [0:0]
def _generate_completions(self, model, prompts):
    """Generate two sets of completions for the given prompts.

    The first set is sampled from the policy model itself; the second is
    sampled from the geometric mixture of the policy and the reference
    model (via ``GeometricMixtureWrapper``).

    Returns:
        A tuple ``(model_output, mixture_output)`` of generation results.
    """
    # Sample from the policy, DDP/FSDP-unwrapped for generation.
    with unwrap_model_for_generation(model, self.accelerator) as generation_model:
        model_output = generation_model.generate(
            input_ids=prompts["input_ids"],
            attention_mask=prompts["attention_mask"],
            generation_config=self.generation_config,
        )

    # DDP/FSDP-unwrapped view of the main model; this is the policy side of
    # the geometric mixture (PEFT adapters remain active if PEFT is used).
    policy = self.accelerator.unwrap_model(model)

    # Select the reference side of the mixture — it must also be
    # DDP/FSDP-unwrapped.
    reference: torch.nn.Module
    if self.ref_model is not None:
        # An explicit reference model was supplied; unwrap it.
        reference = self.accelerator.unwrap_model(self.ref_model)
    elif is_peft_available() and isinstance(policy, PeftModel):
        # No explicit reference model: with PEFT, the adapter-free base of
        # the policy serves as the reference.
        reference = policy.get_base_model()
    else:
        # No explicit reference model and no PEFT base to peel off: fall
        # back to the (unwrapped) policy itself.
        reference = policy

    # Both mixture inputs are DDP-unwrapped; generation from the mixture
    # must not track gradients.
    with torch.no_grad():
        mixture = GeometricMixtureWrapper(
            model=policy,
            ref_model=reference,
            generation_config=self.generation_config,
            mixture_coef=self.mixture_coef,
            device=self.accelerator.device,
        )
        mixture_output = mixture.generate(
            input_ids=prompts["input_ids"],
            attention_mask=prompts["attention_mask"],
            generation_config=self.generation_config,
        )

    return model_output, mixture_output