validation/compare_models.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
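"""Compare cartoonization models on the same validation images.

Runs either the TensorFlow Whitebox-Cartoonizer model or an InstructPix2Pix-style
diffusion pipeline over samples from the validation dataset, then serializes the
original and generated images under comparison-<model_id>/<label>/<image hash>/.

Example invocation (flags as defined below; adjust to your environment):

    python validation/compare_models.py \
        --model_id instruction-tuning-vision/instruction-tuned-cartoonizer \
        --max_num_samples 10
"""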

import argparse
import hashlib
import os
import sys

# Make sibling modules (data_utils, data_preparation) importable when the script
# is run directly from the repository.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from PIL import Image

import data_utils
from data_preparation import model_utils

# Fixed seed so diffusion outputs are reproducible across runs.
GEN = torch.manual_seed(0)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="sayakpaul/whitebox-cartoonizer",
        choices=[
            "sayakpaul/whitebox-cartoonizer",
            "instruction-tuning-vision/instruction-tuned-cartoonizer",
            "timbrooks/instruct-pix2pix",
        ],
    )
    parser.add_argument("--dataset_id", type=str, default="imagenette")
    parser.add_argument("--max_num_samples", type=int, default=10)
    parser.add_argument(
        "--prompt", type=str, default="Generate a cartoonized version of the image"
    )
    parser.add_argument("--num_inference_steps", type=int, default=20)
    parser.add_argument("--image_guidance_scale", type=float, default=1.5)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    args = parser.parse_args()
    return args
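

# Loads the diffusers InstructPix2Pix pipeline on GPU in fp16; xformers
# memory-efficient attention is enabled to keep VRAM usage down during inference.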
def load_pipeline(model_id: str):
    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, use_auth_token=True
    ).to("cuda")
    pipeline.enable_xformers_memory_efficient_attention()
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
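

# For each validation sample: create a results directory keyed by label and image
# hash, run the selected model, and save the original next to the generated image.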
def main(args):
    data_root = f"comparison-{args.model_id}"

    print("Loading validation dataset and inference model...")
    dataset = data_utils.load_dataset(args.dataset_id, args.max_num_samples)

    using_tf = False
    if "sayakpaul" in args.model_id:
        inference = model_utils.load_model(args.model_id)
        using_tf = True
        print(
            "TensorFlow model detected for inference; diffusion-specific parameters won't be used."
        )
    else:
        inference = load_pipeline(args.model_id)

    num_samples_to_generate = (
        args.max_num_samples
        if args.max_num_samples is not None
        else dataset.cardinality()
    )
    print(f"Generating {num_samples_to_generate} images...")

    for sample in dataset.as_numpy_iterator():
        # Result dir creation: comparison-<model_id>/<label>/<sha1 of image bytes>/
        concept_path = os.path.join(data_root, str(sample["label"]))
        hash_image = hashlib.sha1(sample["image"].tobytes()).hexdigest()
        image_path = os.path.join(concept_path, hash_image)
        os.makedirs(image_path, exist_ok=True)

        # Perform inference and serialize the result.
        if using_tf:
            image = model_utils.perform_inference(inference)(sample["image"])
            Image.fromarray(sample["image"]).save(os.path.join(image_path, "original.png"))
            image.save(os.path.join(image_path, "tf_image.png"))
        else:
            image = inference(
                args.prompt,
                image=Image.fromarray(sample["image"]).convert("RGB"),
                num_inference_steps=args.num_inference_steps,
                image_guidance_scale=args.image_guidance_scale,
                guidance_scale=args.guidance_scale,
                generator=GEN,
            ).images[0]
            image_prefix = (
                f"steps@{args.num_inference_steps}"
                f"-igs@{args.image_guidance_scale}"
                f"-gs@{args.guidance_scale}"
            )
            Image.fromarray(sample["image"]).save(os.path.join(image_path, "original.png"))
            image.save(os.path.join(image_path, f"{image_prefix}.png"))


if __name__ == "__main__":
    args = parse_args()
    main(args)