in scripts/calculate_fid.py [0:0]
def generate_and_save_images_flickr_8k(args):
    """Render one image per Flickr8k caption and write each one to ``args.save_path``.

    Workflow:
      1. Create the output directory if it is missing.
      2. Load ``PipelineMuse`` from ``args.model_name_or_path``, move it to
         ``args.device`` and switch its transformer to xformers
         memory-efficient attention.
      3. Batch ``Flickr8kDataset(args.dataset_root, args.dataset_captions_file)``
         with a ``DataLoader`` of size ``args.batch_size``.
      4. Sample every batch with one ``torch.Generator`` seeded from
         ``args.seed``. The same generator is reused across batches, so the
         whole run is reproducible for a fixed seed and batch order.
      5. Save each returned image under its name from the batch.

    Args:
        args: Namespace-like config. Attributes read: ``save_path``,
            ``model_name_or_path``, ``device``, ``dataset_root``,
            ``dataset_captions_file``, ``batch_size``, ``seed``,
            ``timesteps``, ``guidance_scale``, ``temperature``.

    Returns:
        None. The results are the image files written to disk.
    """
    os.makedirs(args.save_path, exist_ok=True)

    logger.warning("Loading pipe")
    pipe = PipelineMuse.from_pretrained(args.model_name_or_path)
    pipe = pipe.to(args.device)
    # Requires xformers to be installed; this raises otherwise.
    pipe.transformer.enable_xformers_memory_efficient_attention()

    logger.warning("Loading data")
    loader = DataLoader(
        Flickr8kDataset(args.dataset_root, args.dataset_captions_file),
        batch_size=args.batch_size,
    )
    rng = torch.Generator(args.device).manual_seed(args.seed)

    logger.warning("Generating images")
    for batch in loader:
        # Element 0 supplies the output file names, element 1 the prompts.
        # NOTE(review): assumes Flickr8kDataset yields (image_name, caption)
        # items. If one image name is paired with several captions, later
        # generations overwrite earlier files. Confirm against the dataset.
        names, captions = batch[0], batch[1]
        generated = pipe(
            captions,
            timesteps=args.timesteps,
            guidance_scale=args.guidance_scale,
            temperature=args.temperature,
            generator=rng,
            use_tqdm=False,
        )
        # zip stops at the shorter sequence. Outputs are presumably PIL
        # images (they expose .save); verify against PipelineMuse.
        for name, picture in zip(names, generated):
            destination = os.path.join(args.save_path, f"{name}")
            picture.save(destination)