in scripts/calculate_fid.py [0:0]
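# The excerpt below relies on a few imports and two helpers defined elsewhere
# in scripts/calculate_fid.py. A minimal sketch, assuming standard-library
# logging and the usual torch/webdataset imports (the real script may wire
# these differently):
import json
import logging
import os
import time

import torch
import webdataset as wds
from torch.utils.data import DataLoader

from muse import PipelineMuse

logger = logging.getLogger(__name__)


def format_shard_number(shard_n):
    # Assumption: shard filenames are zero-padded to five digits, the common
    # webdataset convention (e.g. 00000.tar).
    return "{:0>5}".format(shard_n)


def distribute_shards(start_shard_all, end_shard_all, ntasks):
    # Assumption: split the inclusive shard range into one contiguous
    # (start, end) slice per SLURM task, spreading any remainder over the
    # first tasks. Consistent with how distributed_shards[slurm_procid]
    # is unpacked below.
    total = end_shard_all - start_shard_all + 1
    per_task, leftover = divmod(total, ntasks)
    slices = []
    start = start_shard_all
    for task_idx in range(ntasks):
        n = per_task + (1 if task_idx < leftover else 0)
        slices.append((start, start + n - 1))
        start += n
    return slices
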
def generate_and_save_images_coco(args):
    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.dataset_root, exist_ok=True)

    logger.warning("Loading pipe")
    pipeline = PipelineMuse.from_pretrained(args.model_name_or_path).to(args.device)
    pipeline.transformer.enable_xformers_memory_efficient_attention()
logger.warning("Loading data")
# 20 shards is safe range to get 30k images
start_shard = 0
end_shard = 20
num_images_to_generate = 30_000
if args.slurm:
slurm_ntasks = int(os.environ["SLURM_NTASKS"])
slurm_procid = int(os.environ["SLURM_PROCID"])
distributed_shards = distribute_shards(start_shard, end_shard, slurm_ntasks)
start_shard, end_shard = distributed_shards[slurm_procid]
num_images_to_generate = round(num_images_to_generate / slurm_ntasks)
logger.warning("************")
logger.warning("Running as slurm task")
logger.warning(f"SLURM_NTASKS: {slurm_ntasks}")
logger.warning(f"SLURM_PROCID: {slurm_procid}")
logger.warning(f"start_shard: {start_shard}, end_shard: {end_shard}")
logger.warning("************")
logger.warning(f"all slurm processes")
for slurm_proc_id_, (proc_start_shard, proc_end_shard) in enumerate(distributed_shards):
logger.warning(
f"slurm process: {slurm_proc_id_}, start_shard: {proc_start_shard}, end_shard: {proc_end_shard}"
)
logger.warning("************")
    # webdataset expands the brace range (e.g. {00000..00020}) into one URL
    # per shard, each streamed from S3 via the aws cli.
    shard_range = "{" + format_shard_number(start_shard) + ".." + format_shard_number(end_shard) + "}"
    download_shards = f"pipe:aws s3 cp s3://muse-datasets/coco/2017/train/{shard_range}.tar -"
    logger.warning(f"downloading shards {download_shards}")
    dataset = (
        wds.WebDataset(download_shards)
        .decode("pil", handler=wds.warn_and_continue)
        .rename(image="jpg;png;jpeg;webp", metadata="json")
        .map(
            # Keep only the key, the decoded image, and its metadata; don't
            # shadow the builtin `dict` as the lambda argument.
            lambda sample: {
                "__key__": sample["__key__"],
                "image": sample["image"],
                "metadata": sample["metadata"],
            }
        )
        .to_tuple("__key__", "image", "metadata")
        .batched(args.batch_size)
    )
    # The webdataset pipeline already batches, so the DataLoader must not
    # re-batch (batch_size=None).
    dataloader = DataLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=0,
    )
    generator = torch.Generator(args.device).manual_seed(args.seed)

    logger.warning("Generating images")

    num_images_generated = 0
    for __key__, real_image, metadata in dataloader:
        logger.warning(f"Creating {len(__key__)} images: {__key__[0]} {__key__[-1]}")
        num_images_generated += len(__key__)

        # Use the first COCO caption of each image as the generation prompt.
        text = [json.loads(x["annotations"])[0]["caption"] for x in metadata]

        t0 = time.perf_counter()
        generated_image = pipeline(
            text,
            timesteps=args.timesteps,
            guidance_scale=args.guidance_scale,
            temperature=args.temperature,
            generator=generator,
            use_tqdm=False,
        )
        logger.warning(f"Generation time: {time.perf_counter() - t0:.2f}s")

        # Save real/generated pairs under the same key so the two image trees
        # line up for FID; inner names avoid shadowing the batch variables.
        for key, gen_image, real_img in zip(__key__, generated_image, real_image):
            real_img.save(os.path.join(args.dataset_root, f"{key}.png"))
            gen_image.save(os.path.join(args.save_path, f"{key}.png"))

        logger.warning(f"Generated {num_images_generated}/{num_images_to_generate}")

        if num_images_generated >= num_images_to_generate:
            logger.warning("done")
            break
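

# A hedged sketch of how this function might be driven from the command line;
# the real parser lives elsewhere in the script, so every flag name and
# default below is an assumption inferred from the args.* attributes used
# above.
import argparse


def parse_args_sketch():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--save_path", type=str, required=True)  # generated images
    parser.add_argument("--dataset_root", type=str, required=True)  # real images
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--timesteps", type=int, default=12)
    parser.add_argument("--guidance_scale", type=float, default=8.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--slurm", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    generate_and_save_images_coco(parse_args_sketch())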