in trl/trainer/online_dpo_trainer.py
def _generate_vllm(self, model, prompts):
    # processing_class is the tokenizer (or processor) attached to the trainer
    eos_token_id = self.processing_class.eos_token_id
    pad_token_id = self.processing_class.pad_token_id
    # Load the latest weights
    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
    llm_model.load_weights(model.state_dict().items())
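    # From here on the engine samples with the policy's current parameters, so the
    # generations track the model as it trains rather than the weights vLLM was
    # initialized with.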
    # Chat-formatted prompts go through llm.chat (which applies the chat template);
    # plain-text prompts go through llm.generate.
    if is_conversational({"prompt": prompts[0]}):
        outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False)
    else:
        outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False)
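    # Two completions are sampled per prompt (the sampling params built elsewhere in
    # the trainer are expected to request two samples, i.e. n=2), to be scored later
    # into chosen/rejected pairs. The flattened lists put every prompt's first sample
    # before every prompt's second sample, so prompt_ids is duplicated in the same layout.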
    completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
    prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
    # Create mask and pad the prompt and completion
    # Prompts are left-padded to the longest prompt; the mask marks real tokens with 1.
    max_prompt_length = max(len(ids) for ids in prompt_ids)
    prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
    prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
    max_tokens = self.generation_config.max_tokens
    # Completions are right-padded to max_tokens. The mask is built before the EOS
    # append below, so an EOS added there lands on a masked position: the mask covers
    # only tokens vLLM actually generated.
    completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
    # If generation stopped before max_tokens without emitting EOS, append one.
    completion_ids = [
        ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
        for ids in completion_ids
    ]
    completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
    # Convert to tensors
    prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
    prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
    completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
    completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
    return prompt_ids, prompt_mask, completion_ids, completion_mask
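A minimal, self-contained sketch of the padding/masking scheme above on toy values.
The token ids and the constants pad_token_id=0, eos_token_id=2, max_tokens=5 are
illustrative assumptions, not values taken from the trainer:

    pad_token_id, eos_token_id, max_tokens = 0, 2, 5

    prompt_ids = [[11, 12, 13], [21, 22]]              # ragged prompts
    completion_ids = [[31, 32], [41, 42, 43, 44, 45]]  # stopped early w/o EOS; hit max_tokens

    # Left-pad prompts to the longest prompt
    max_prompt_length = max(len(ids) for ids in prompt_ids)
    prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
    prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]

    # Mask, append EOS where needed, then right-pad completions to max_tokens
    completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
    completion_ids = [
        ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
        for ids in completion_ids
    ]
    completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]

    print(prompt_ids)       # [[11, 12, 13], [0, 21, 22]]
    print(prompt_mask)      # [[1, 1, 1], [0, 1, 1]]
    print(completion_ids)   # [[31, 32, 2, 0, 0], [41, 42, 43, 44, 45]]
    print(completion_mask)  # [[1, 1, 0, 0, 0], [1, 1, 1, 1, 1]]

Note how the first completion's appended EOS sits on a masked position, since the
mask is computed from the pre-append lengths. A caller would presumably consume the
method as, e.g., prompt_ids, prompt_mask, completion_ids, completion_mask =
self._generate_vllm(model, prompts).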