def main()

in scripts/gen_sdxl_synthetic_dataset.py


import argparse
import os
import time

import torch
import webdataset as wds
from diffusers import DiffusionPipeline
from transformers import CLIPModel, CLIPProcessor

# `logger`, `get_captions`, `take_up_to`, `split_list`, and `format_shard_number`
# are assumed to be defined elsewhere in this script.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--slurm", action="store_true")
    parser.add_argument("--n_shards_to_write", required=True, type=int)
    args = parser.parse_args()

    if args.slurm:
        slurm_procid = int(os.environ["SLURM_PROCID"])
        # `1 +` because we already used caption shard 0 while doing testing
        caption_shard_n = 1 + slurm_procid
    else:
        caption_shard_n = 0

    device = "cuda"

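    # CLIP ViT-L/14 in fp16; used below to score each generated image against its prompt.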
    clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
    clip.to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

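    # SDXL base pipeline; the offset-noise example LoRA shipped alongside the base
    # checkpoint is fused in at scale 0.4.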
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.load_lora_weights(
        "stabilityai/stable-diffusion-xl-base-1.0",
        weight_name="sd_xl_offset_example-lora_1.0.safetensors",
    )
    pipe.to(device)
    pipe.fuse_lora(lora_scale=0.4)
    pipe.enable_xformers_memory_efficient_attention()
    pipe.vae.enable_slicing()

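    # The refiner shares the base pipeline's second text encoder and VAE, so they are
    # loaded only once.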
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=pipe.text_encoder_2,
        vae=pipe.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to("cuda")

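    # Fetch this task's caption shard and split it into chunks of up to 500 captions;
    # each chunk is written out as one tar shard.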
    captions = get_captions(caption_shard_n)
    captions = take_up_to(captions, 500)

    for shard_n, captions_ in enumerate(captions):
        t0 = time.perf_counter()

        logger.warning(f"shard_n {shard_n}")

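        # Stream the tar shard straight to S3 instead of staging it on local disk.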
        writer = wds.TarWriter(
            "pipe:aws s3 cp -"
            f" s3://muse-datasets/sdxl-synthetic-dataset/{caption_shard_n}/{format_shard_number(shard_n)}.tar"
        )

        key = 0

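        # Generate in batches of 8 prompts, 4 images per prompt.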
        for caption_batch_idx, captions__ in enumerate(split_list(captions_, 8)):
            logger.warning(f"caption_batch_idx {caption_batch_idx}")

            num_inference_steps = 35
            num_images_per_prompt = 4
            proportion_base_model = 0.8

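            # Run the first 80% of the denoising schedule on the base model, emitting
            # latents that the refiner finishes.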
            images = pipe(
                prompt=captions__,
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=num_images_per_prompt,
                denoising_end=proportion_base_model,
                output_type="latent",
            ).images
            images = refiner(
                prompt=captions__,
                num_inference_steps=num_inference_steps,
                denoising_start=proportion_base_model,
                image=images,
                num_images_per_prompt=num_images_per_prompt,
            ).images

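            # The pipeline returns each prompt's images consecutively, so chunks of 4
            # (num_images_per_prompt) recover the group belonging to each caption.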
            for caption, images_ in zip(captions__, split_list(images, 4)):
                # TODO - can we avoid syncing images to cpu
                inputs = clip_processor(
                    text=caption,
                    images=images_,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,
                    truncation=True,
                )
                inputs["pixel_values"] = inputs["pixel_values"].to(dtype=torch.float16, device=device)
                inputs["input_ids"] = inputs["input_ids"].to(device)
                inputs["attention_mask"] = inputs["attention_mask"].to(device)

                clip_scores = clip(**inputs).logits_per_image.flatten().tolist()
                clip_scores = ",".join(str(x) for x in clip_scores)

                logger.warning(f"__key__ {key}")

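                # One sample per caption: four PNGs, the prompt, and its comma-separated CLIP scores.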
                writer.write(
                    {
                        "__key__": format_shard_number(key),
                        "0.png": images_[0],
                        "1.png": images_[1],
                        "2.png": images_[2],
                        "3.png": images_[3],
                        "txt": caption,
                        "clip_scores.txt": clip_scores,
                    }
                )

                key += 1

        writer.close()

        logger.warning(f"shard_n {shard_n} {time.perf_counter() - t0}")

        # `shard_n` is zero-based, so stop once `n_shards_to_write` shards have been
        # written (the original `>` wrote one extra shard).
        if shard_n + 1 >= args.n_shards_to_write:
            break