def main()

in validation/compare_models.py [0:0]


def main(args) -> None:
    """Run a model over a validation dataset and save inputs and outputs side by side.

    For every dataset sample, the input image and the model output are written to
    ``comparison-<model_id>/<label>/<sha1 of the image bytes>/``:

    * ``original.png`` -- the unmodified input image.
    * ``tf_image.png`` -- the output, when the TensorFlow code path is used.
    * ``steps@<n>-igs@<x>-gs@<y>.png`` -- the output of the pipeline code path,
      named after the sampling parameters that produced it.

    The TensorFlow code path is selected when ``args.model_id`` contains the
    substring ``"sayakpaul"``; otherwise the model is loaded with
    ``load_pipeline``.

    Args:
        args: Object exposing the attributes ``model_id``, ``dataset_id`` and
            ``max_num_samples``, plus ``prompt``, ``num_inference_steps``,
            ``image_guidance_scale`` and ``guidance_scale`` (the last four are
            only read on the pipeline code path).
    """
    data_root = f"comparison-{args.model_id}"

    print("Loading validation dataset and inference model...")
    dataset = data_utils.load_dataset(args.dataset_id, args.max_num_samples)

    using_tf = "sayakpaul" in args.model_id
    if using_tf:
        model = model_utils.load_model(args.model_id)
        print(
            "TensorFlow model detected for inference, Diffusion-specifc parameters won't be used."
        )
        # Build the inference callable once instead of once per sample.
        # NOTE(review): assumes perform_inference() is a side-effect-free
        # factory that returns a reusable callable -- confirm in model_utils.
        inference = model_utils.perform_inference(model)
        output_name = "tf_image.png"
    else:
        inference = load_pipeline(args.model_id)
        # The sampling parameters are constant for the whole run, so the
        # output file name can be computed once up front.
        output_name = (
            f"steps@{args.num_inference_steps}"
            f"-igs@{args.image_guidance_scale}"
            f"-gs@{args.guidance_scale}.png"
        )

    # Only used for the progress message below; the loop itself simply
    # consumes whatever the dataset yields.
    num_samples_to_generate = (
        args.max_num_samples
        if args.max_num_samples is not None
        else dataset.cardinality()
    )
    print(f"Generating {num_samples_to_generate} images...")

    for sample in dataset.as_numpy_iterator():
        source_array = sample["image"]

        # Result dir creation. The directory is keyed by the SHA-1 of the raw
        # image bytes, so an identical input always maps to the same directory.
        # NOTE(review): if the label arrives as `bytes`, str() yields "b'...'"
        # as the directory name -- confirm the label dtype.
        image_dir = os.path.join(
            data_root,
            str(sample["label"]),
            hashlib.sha1(source_array.tobytes()).hexdigest(),
        )
        os.makedirs(image_dir, exist_ok=True)

        original = Image.fromarray(source_array)

        # Perform inference and serialize the result.
        if using_tf:
            image = inference(source_array)
        else:
            image = inference(
                args.prompt,
                # convert() returns a new image, so `original` is saved as-is.
                image=original.convert("RGB"),
                num_inference_steps=args.num_inference_steps,
                image_guidance_scale=args.image_guidance_scale,
                guidance_scale=args.guidance_scale,
                generator=GEN,
            ).images[0]

        original.save(os.path.join(image_dir, "original.png"))
        image.save(os.path.join(image_dir, output_name))