# generate_and_save_images_coco — excerpt from scripts/calculate_fid.py

def generate_and_save_images_coco(args):
    """Generate images from COCO captions and save real/generated pairs to disk.

    For each COCO sample streamed from S3, the real image is written to
    ``args.dataset_root`` and the pipeline-generated image to
    ``args.save_path``, both named ``<webdataset key>.png`` so they can be
    paired up later (e.g. for FID). When ``args.slurm`` is set, the shard
    range and the 30k-image budget are divided across SLURM tasks.

    Args:
        args: Namespace with ``save_path``, ``dataset_root``,
            ``model_name_or_path``, ``device``, ``slurm``, ``batch_size``,
            ``seed``, ``timesteps``, ``guidance_scale``, ``temperature``.
    """
    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.dataset_root, exist_ok=True)

    logger.warning("Loading pipe")
    pipeline = PipelineMuse.from_pretrained(args.model_name_or_path).to(args.device)
    pipeline.transformer.enable_xformers_memory_efficient_attention()

    logger.warning("Loading data")

    # 20 shards is safe range to get 30k images
    start_shard = 0
    end_shard = 20
    num_images_to_generate = 30_000

    if args.slurm:
        slurm_ntasks = int(os.environ["SLURM_NTASKS"])
        slurm_procid = int(os.environ["SLURM_PROCID"])

        # Each task takes its slice of the shard range and a proportional
        # share of the total image budget.
        distributed_shards = distribute_shards(start_shard, end_shard, slurm_ntasks)

        start_shard, end_shard = distributed_shards[slurm_procid]
        num_images_to_generate = round(num_images_to_generate / slurm_ntasks)

        logger.warning("************")
        logger.warning("Running as slurm task")
        logger.warning(f"SLURM_NTASKS: {slurm_ntasks}")
        logger.warning(f"SLURM_PROCID: {slurm_procid}")
        logger.warning(f"start_shard: {start_shard}, end_shard: {end_shard}")
        logger.warning("************")
        logger.warning("all slurm processes")
        for slurm_proc_id_, (proc_start_shard, proc_end_shard) in enumerate(distributed_shards):
            logger.warning(
                f"slurm process: {slurm_proc_id_}, start_shard: {proc_start_shard}, end_shard: {proc_end_shard}"
            )
        logger.warning("************")

    # Brace range (e.g. {00000..00020}) is expanded by webdataset itself.
    shard_range = "{" + format_shard_number(start_shard) + ".." + format_shard_number(end_shard) + "}"
    download_shards = f"pipe:aws s3 cp s3://muse-datasets/coco/2017/train/{shard_range}.tar -"

    logger.warning(f"downloading shards {download_shards}")

    dataset = (
        wds.WebDataset(download_shards)
        .decode("pil", handler=wds.warn_and_continue)
        .rename(image="jpg;png;jpeg;webp", metadata="json")
        .map(
            # `sample` (not `dict`) so the builtin isn't shadowed.
            lambda sample: {
                "__key__": sample["__key__"],
                "image": sample["image"],
                "metadata": sample["metadata"],
            }
        )
        .to_tuple("__key__", "image", "metadata")
        .batched(args.batch_size)
    )
    # batch_size=None: batching is already done inside the webdataset pipeline.
    dataloader = DataLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=0,
    )

    generator = torch.Generator(args.device).manual_seed(args.seed)

    logger.warning("Generating images")

    num_images_generated = 0

    for keys, real_images, metadatas in dataloader:
        logger.warning(f"Creating {len(keys)} images: {keys[0]} {keys[-1]}")
        num_images_generated += len(keys)

        # Prompt is the first caption annotation of each sample.
        text = [json.loads(x["annotations"])[0]["caption"] for x in metadatas]

        t0 = time.perf_counter()

        generated_images = pipeline(
            text,
            timesteps=args.timesteps,
            guidance_scale=args.guidance_scale,
            temperature=args.temperature,
            generator=generator,
            use_tqdm=False,
        )

        logger.warning(f"Generation time {time.perf_counter() - t0}")

        # Distinct loop names: the original re-bound the batch variables here,
        # which only worked because zip() evaluates its arguments eagerly.
        for key, generated_image, real_image in zip(keys, generated_images, real_images):
            real_image.save(os.path.join(args.dataset_root, f"{key}.png"))
            generated_image.save(os.path.join(args.save_path, f"{key}.png"))

        logger.warning(f"Generated {num_images_generated}/{num_images_to_generate}")

        if num_images_generated >= num_images_to_generate:
            logger.warning("done")
            break