def generate_and_save_images_flickr_8k()

in scripts/calculate_fid.py [0:0]


def generate_and_save_images_flickr_8k(args):
    """Render an image for every Flickr8k caption and write it under ``args.save_path``.

    Steps:
      1. Ensure the output directory ``args.save_path`` exists.
      2. Load ``PipelineMuse`` from ``args.model_name_or_path`` and move it to
         ``args.device``. Then switch its transformer to xformers
         memory-efficient attention. Presumably this requires xformers to be
         installed; TODO confirm.
      3. Batch ``Flickr8kDataset(args.dataset_root, args.dataset_captions_file)``
         with ``args.batch_size``. The DataLoader uses its default (no shuffling).
      4. For each batch, sample images with ``args.timesteps``,
         ``args.guidance_scale`` and ``args.temperature``. Save each image as
         ``<save_path>/<image name>``.

    Args:
        args: Namespace-like object. The fields listed above are read from it,
            plus ``args.seed``.

    Returns:
        None. Output is written to disk only.
    """
    os.makedirs(args.save_path, exist_ok=True)

    logger.warning("Loading pipe")
    muse = PipelineMuse.from_pretrained(args.model_name_or_path)
    muse = muse.to(args.device)
    muse.transformer.enable_xformers_memory_efficient_attention()

    logger.warning("Loading data")
    loader = DataLoader(
        Flickr8kDataset(args.dataset_root, args.dataset_captions_file),
        batch_size=args.batch_size,
    )
    # One seeded generator is shared by every batch. Outputs are reproducible
    # for a fixed seed, batch size and dataset order.
    rng = torch.Generator(args.device).manual_seed(args.seed)

    logger.warning("Generating images")
    for batch in loader:
        # Assumes the dataset yields (image_name, caption) pairs, so the
        # collated batch is [names, captions]. Verify against Flickr8kDataset.
        names, captions = batch[0], batch[1]

        generated = muse(
            captions,
            timesteps=args.timesteps,
            guidance_scale=args.guidance_scale,
            temperature=args.temperature,
            generator=rng,
            use_tqdm=False,
        )

        # Each image reuses its source name, so reruns overwrite earlier output
        # (assuming .save overwrites, as PIL's does; confirm the return type).
        for name, img in zip(names, generated):
            img.save(os.path.join(args.save_path, f"{name}"))