in trl/trainer/alignprop_trainer.py
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
"""
Generate samples from the model
Args:
batch_size (int): Batch size to use for sampling
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
Returns:
prompt_image_pairs (dict[Any])
"""
    prompt_image_pairs = {}
    # Negative prompt embedding, tiled to the batch size for classifier-free guidance.
    sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
    # Draw prompts from the prompt function unless they were passed in explicitly.
    if prompts is None:
        prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
    else:
        prompt_metadata = [{} for _ in range(batch_size)]
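    # prompt_fn is expected to return a (prompt, metadata) pair per call; a
    # minimal hypothetical example:
    #
    #   def prompt_fn():
    #       return "a photo of a cat", {}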
    # Tokenize and encode the prompts with the pipeline's text encoder.
    prompt_ids = self.sd_pipeline.tokenizer(
        prompts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=self.sd_pipeline.tokenizer.model_max_length,
    ).input_ids.to(self.accelerator.device)
    prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
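    # prompt_ids has shape [batch_size, model_max_length] (77 for CLIP tokenizers);
    # prompt_embeds is the encoder's last hidden state, of shape
    # [batch_size, model_max_length, hidden_dim].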
    if with_grad:
        # Differentiable sampling path: keeps the computation graph so reward
        # gradients can be backpropagated through the denoising steps, subject
        # to the truncated-backprop settings below.
        sd_output = self.sd_pipeline.rgb_with_grad(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=sample_neg_prompt_embeds,
            num_inference_steps=self.config.sample_num_steps,
            guidance_scale=self.config.sample_guidance_scale,
            eta=self.config.sample_eta,
            truncated_backprop_rand=self.config.truncated_backprop_rand,
            truncated_backprop_timestep=self.config.truncated_backprop_timestep,
            truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
            output_type="pt",
        )
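    # Sketch of the truncated-backprop idea implied by the config names above
    # (an assumption about rgb_with_grad's internals, not a verbatim excerpt):
    # inside the denoising loop the graph is cut before a cutoff step K, so that
    # only the later steps are differentiated through:
    #
    #   K = random.randint(*minmax) if truncated_backprop_rand else truncated_backprop_timestep
    #   if step < K:
    #       noise_pred = noise_pred.detach()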
    else:
        # Standard sampling path with no gradients attached to the outputs.
        sd_output = self.sd_pipeline(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=sample_neg_prompt_embeds,
            num_inference_steps=self.config.sample_num_steps,
            guidance_scale=self.config.sample_guidance_scale,
            eta=self.config.sample_eta,
            output_type="pt",
        )
    images = sd_output.images
    prompt_image_pairs["images"] = images
    prompt_image_pairs["prompts"] = prompts
    prompt_image_pairs["prompt_metadata"] = prompt_metadata
    return prompt_image_pairs
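
# A minimal usage sketch (hypothetical trainer instance; assumes the trainer was
# constructed with a prompt_fn, a reward function, and an SD pipeline exposing
# the attributes used above):
#
#   pairs = trainer._generate_samples(batch_size=4)   # with_grad=True by default
#   images = pairs["images"]                          # [4, 3, H, W], graph attached
#   rewards = trainer.reward_fn(images, pairs["prompts"], pairs["prompt_metadata"])
#   loss = -rewards.mean()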