Excerpt: `_generate_completions` method

From: trl/trainer/nash_md_trainer.py


    def _generate_completions(self, model, prompts):
        """Generate one completion from the policy and one from the policy/reference mixture.

        Args:
            model: The (possibly DDP/FSDP-wrapped) policy model being trained.
            prompts: Dict with `input_ids` and `attention_mask` tensors for the prompt batch.

        Returns:
            Tuple `(model_output, mixture_output)` of `generate()` results from the
            policy model and from the `GeometricMixtureWrapper`, respectively.
        """
        # Sample completions from the policy inside the generation-unwrapping context.
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy:
            model_output = unwrapped_policy.generate(
                input_ids=prompts["input_ids"],
                attention_mask=prompts["attention_mask"],
                generation_config=self.generation_config,
            )

        # DDP/FSDP-unwrapped policy (PEFT adapters, if any, remain active) — this is
        # the "model" side of the geometric mixture.
        policy_model = self.accelerator.unwrap_model(model)

        # Select the reference model for the mixture; it must also be DDP/FSDP-unwrapped.
        reference_model: torch.nn.Module
        if self.ref_model is not None:
            # An explicit reference model was supplied — just unwrap it.
            reference_model = self.accelerator.unwrap_model(self.ref_model)
        elif is_peft_available() and isinstance(policy_model, PeftModel):
            # No explicit reference: with PEFT, the adapter-free base model serves
            # as the reference.
            reference_model = policy_model.get_base_model()
        else:
            # No explicit reference and no PEFT: fall back to the unwrapped policy itself.
            reference_model = policy_model

        # Build the geometric mixture of policy and reference, then sample from it.
        # no_grad: mixture generation never contributes gradients.
        with torch.no_grad():
            mixture_model = GeometricMixtureWrapper(
                model=policy_model,
                ref_model=reference_model,
                generation_config=self.generation_config,
                mixture_coef=self.mixture_coef,
                device=self.accelerator.device,
            )

            mixture_output = mixture_model.generate(
                input_ids=prompts["input_ids"],
                attention_mask=prompts["attention_mask"],
                generation_config=self.generation_config,
            )

        return model_output, mixture_output