scripts/calculate_fid.py (193 lines of code) (raw):

import json import logging import os import time from argparse import ArgumentParser import pandas as pd import torch import webdataset as wds from torch.utils.data import DataLoader, Dataset try: from cleanfid import fid except: raise ImportError("Please install cleanfid: pip install clean_fid") from muse import PipelineMuse logger = logging.getLogger(__name__) class Flickr8kDataset(Dataset): def __init__(self, root_dir, captions_file): self.root_dir = root_dir self.captions_file = captions_file df = pd.read_csv(captions_file, sep="\t", names=["image_name", "caption"]) df["image_name"] = df["image_name"].apply(lambda name: name.split("#")[0]) self.images = df["image_name"].unique().tolist() self.captions = [df[df["image_name"] == name]["caption"].tolist()[0] for name in self.images] def __len__(self): return len(self.images) def __getitem__(self, idx): return self.images[idx], self.captions[idx] def generate_and_save_images_flickr_8k(args): """ Generate images from captions and save them to disk. """ os.makedirs(args.save_path, exist_ok=True) logger.warning("Loading pipe") pipeline = PipelineMuse.from_pretrained(args.model_name_or_path).to(args.device) pipeline.transformer.enable_xformers_memory_efficient_attention() logger.warning("Loading data") dataset = Flickr8kDataset(args.dataset_root, args.dataset_captions_file) dataloader = DataLoader(dataset, batch_size=args.batch_size) generator = torch.Generator(args.device).manual_seed(args.seed) logger.warning("Generating images") for batch in dataloader: image_names = batch[0] text = batch[1] images = pipeline( text, timesteps=args.timesteps, guidance_scale=args.guidance_scale, temperature=args.temperature, generator=generator, use_tqdm=False, ) for image_name, image in zip(image_names, images): image.save(os.path.join(args.save_path, f"{image_name}")) def distribute_shards(start_shard_all, end_shard_all, slurm_ntasks): total_shards = end_shard_all - start_shard_all + 1 shards_per_task = total_shards // slurm_ntasks shards_per_task = [shards_per_task] * slurm_ntasks # to distribute the remainder of tasks for non-evenly divisible number of shards left_over_shards = total_shards % slurm_ntasks for slurm_procid in range(left_over_shards): shards_per_task[slurm_procid] += 1 assert sum(shards_per_task) == total_shards distributed_shards = [] for slurm_procid in range(len(shards_per_task)): if slurm_procid == 0: start_shard = start_shard_all else: start_shard = distributed_shards[slurm_procid - 1][1] + 1 end_shard = start_shard + shards_per_task[slurm_procid] - 1 distributed_shards.append((start_shard, end_shard)) assert sum([end_shard - start_shard + 1 for start_shard, end_shard in distributed_shards]) == total_shards return distributed_shards def format_shard_number(shard_n: int): return "{:0>{}}".format(shard_n, 5) def generate_and_save_images_coco(args): os.makedirs(args.save_path, exist_ok=True) os.makedirs(args.dataset_root, exist_ok=True) logger.warning("Loading pipe") pipeline = PipelineMuse.from_pretrained(args.model_name_or_path).to(args.device) pipeline.transformer.enable_xformers_memory_efficient_attention() logger.warning("Loading data") # 20 shards is safe range to get 30k images start_shard = 0 end_shard = 20 num_images_to_generate = 30_000 if args.slurm: slurm_ntasks = int(os.environ["SLURM_NTASKS"]) slurm_procid = int(os.environ["SLURM_PROCID"]) distributed_shards = distribute_shards(start_shard, end_shard, slurm_ntasks) start_shard, end_shard = distributed_shards[slurm_procid] num_images_to_generate = round(num_images_to_generate / slurm_ntasks) logger.warning("************") logger.warning("Running as slurm task") logger.warning(f"SLURM_NTASKS: {slurm_ntasks}") logger.warning(f"SLURM_PROCID: {slurm_procid}") logger.warning(f"start_shard: {start_shard}, end_shard: {end_shard}") logger.warning("************") logger.warning(f"all slurm processes") for slurm_proc_id_, (proc_start_shard, proc_end_shard) in enumerate(distributed_shards): logger.warning( f"slurm process: {slurm_proc_id_}, start_shard: {proc_start_shard}, end_shard: {proc_end_shard}" ) logger.warning("************") shard_range = "{" + format_shard_number(start_shard) + ".." + format_shard_number(end_shard) + "}" download_shards = f"pipe:aws s3 cp s3://muse-datasets/coco/2017/train/{shard_range}.tar -" logger.warning(f"downloading shards {download_shards}") dataset = ( wds.WebDataset(download_shards) .decode("pil", handler=wds.warn_and_continue) .rename(image="jpg;png;jpeg;webp", metadata="json") .map( lambda dict: { "__key__": dict["__key__"], "image": dict["image"], "metadata": dict["metadata"], } ) .to_tuple("__key__", "image", "metadata") .batched(args.batch_size) ) dataloader = DataLoader( dataset, batch_size=None, shuffle=False, num_workers=0, ) generator = torch.Generator(args.device).manual_seed(args.seed) logger.warning("Generating images") num_images_generated = 0 for __key__, real_image, metadata in dataloader: logger.warning(f"Creating {len(__key__)} images: {__key__[0]} {__key__[-1]}") num_images_generated += len(__key__) text = [json.loads(x["annotations"])[0]["caption"] for x in metadata] t0 = time.perf_counter() generated_image = pipeline( text, timesteps=args.timesteps, guidance_scale=args.guidance_scale, temperature=args.temperature, generator=generator, use_tqdm=False, ) logger.warning(f"Generation time {time.perf_counter() - t0}") for __key__, generated_image, real_image in zip(__key__, generated_image, real_image): real_image.save(os.path.join(args.dataset_root, f"{__key__}.png")) generated_image.save(os.path.join(args.save_path, f"{__key__}.png")) logger.warning(f"Generated {num_images_generated}/{num_images_to_generate}") if num_images_generated >= num_images_to_generate: logger.warning("done") break def main(args): if args.do in ["full", "generate_and_save_images"]: if args.dataset == "flickr_8k": generate_and_save_images_flickr_8k(args) elif args.dataset == "coco": generate_and_save_images_coco(args) else: assert False if args.do in ["full", "compute_fid"]: real_images = args.dataset_root generated_images = args.save_path logger.warning("computing FiD") score_clean = fid.compute_fid(real_images, generated_images, mode="clean", num_workers=0) logger.warning(f"clean-fid score is {score_clean:.3f}") if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--model_name_or_path", type=str, required=False) parser.add_argument("--dataset_root", type=str, required=True) parser.add_argument("--dataset_captions_file", type=str, required=False) parser.add_argument("--save_path", type=str, required=True) parser.add_argument("--timesteps", type=int, default=12) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--guidance_scale", type=float, default=8.0) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--seed", type=int, default=2028) parser.add_argument("--dataset", type=str, default="flickr_8k", choices=("flickr_8k", "coco")) parser.add_argument("--do", type=str, default="full", choices=("full", "generate_and_save_images", "compute_fid")) parser.add_argument( "--slurm", action="store_true", help="Set when running as a slurm job to distribute coco image generation among multiple GPUs", ) args = parser.parse_args() if args.do in ["full", "generated_and_save_images"]: if args.dataset == "flickr_8k" and args.dataset_captions_file is None: raise ValueError("`--dataset=flickr_8k` requires setting `--dataset_captions_file`") if args.model_name_or_path is None: raise ValueError("`--do=full|generate_and_save_images` requires setting `--model_name_or_path`") if args.do == "full": logger.warning("generating images and calculating fid") elif args.do == "generate_and_save_images": logger.warning("just generating and saving images") elif args.do == "compute_fid": logger.warning("just computing fid") else: assert False main(args)