from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
import torch

import os
import glob
import json
import argparse
import numpy as np
from tqdm import tqdm


SAVE_FOLDER = "./load/images_dreamfusion"

# Appended to every library prompt before encoding to steer the cascade
# towards photorealistic, high-detail renders.
PROMPT_SUFFIX = ", 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, hyperrealistic, intricate details, ultra-realistic, award-winning"


def shard_prompts(prompts, rank, world_size):
    """Return the contiguous slice of *prompts* assigned to worker *rank*.

    The list is cut into chunks of ``ceil(len(prompts) / world_size)`` items;
    trailing ranks may receive fewer prompts, or none at all.
    """
    per_rank = (len(prompts) + world_size - 1) // world_size  # ceiling division
    return prompts[rank * per_rank : (rank + 1) * per_rank]


def count_images(folder):
    """Return the number of ``*.png`` files directly inside *folder*."""
    # Folder names are raw prompts and may contain glob metacharacters
    # (``[``, ``*``, ``?``), so the directory part must be escaped.
    return len(glob.glob(os.path.join(glob.escape(folder), "*.png")))


def fresh_seed(folder):
    """Draw a random seed whose output image does not exist yet in *folder*.

    Returns ``(seed, path)``. Redrawing on collision prevents silently
    overwriting a previously generated image.
    """
    while True:
        seed = int(np.random.randint(low=0, high=10000000))
        path = os.path.join(folder, f"img_{seed:08d}.png")
        if not os.path.exists(path):
            return seed, path


def load_pipelines():
    """Load the three cascade stages (IF-I, IF-II, SD x4 upscaler) from the local cache only."""
    # stage 1: text -> low-resolution image
    stage_1 = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-I-XL-v1.0",
        variant="fp16",
        torch_dtype=torch.float16,
        local_files_only=True,
    )
    stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
    stage_1.enable_model_cpu_offload()

    # stage 2: super-resolution; text_encoder=None because it is fed the
    # prompt embeddings already computed by stage 1.
    stage_2 = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-II-L-v1.0",
        text_encoder=None,
        variant="fp16",
        torch_dtype=torch.float16,
        local_files_only=True,
    )
    # NOTE: xformers attention is left disabled for stage 2 (the call was
    # commented out in this setup); re-enable if it works for this stage.
    stage_2.enable_model_cpu_offload()

    # stage 3: 4x upscaler (loaded without IF's safety modules)
    stage_3 = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler",
        torch_dtype=torch.float16,
        local_files_only=True,
    )
    stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
    stage_3.enable_model_cpu_offload()

    return stage_1, stage_2, stage_3


def main():
    """Generate up to ``--n_images`` images per prompt for this worker's shard.

    Images are written to ``SAVE_FOLDER/<prompt>/img_<seed>.png``. Prompts
    whose folder already holds enough images are skipped. Partially filled
    folders are topped up rather than regenerated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rank", default=0, type=int, help="index of this worker (GPU), in [0, world_size)")
    parser.add_argument("--world_size", default=8, type=int, help="total number of workers sharing the prompt library")
    parser.add_argument("--n_images", default=30, type=int, help="target number of images per prompt")
    args = parser.parse_args()

    # Fail fast on an invalid shard spec instead of silently doing nothing
    # (or dividing by zero).
    if args.world_size < 1:
        parser.error("--world_size must be >= 1")
    if not 0 <= args.rank < args.world_size:
        parser.error("--rank must be in [0, world_size)")

    stage_1, stage_2, stage_3 = load_pipelines()

    # load prompt library
    with open("load/prompt_library.json", "r", encoding="utf-8") as f:
        prompt_library = json.load(f)

    for prompt in tqdm(shard_prompts(prompt_library["dreamfusion"], args.rank, args.world_size)):
        print("Prompt:", prompt)

        # NOTE(review): the raw prompt is used as a directory name; a prompt
        # containing "/" would create nested folders.
        save_folder = os.path.join(SAVE_FOLDER, prompt)
        os.makedirs(save_folder, exist_ok=True)

        # Only generate what is missing so interrupted runs resume cleanly.
        n_missing = args.n_images - count_images(save_folder)
        if n_missing <= 0:
            continue

        full_prompt = prompt + PROMPT_SUFFIX
        prompt_embeds, negative_embeds = stage_1.encode_prompt(full_prompt)
        for _ in tqdm(range(n_missing)):
            seed, out_path = fresh_seed(save_folder)
            generator = torch.manual_seed(seed)

            # Stage 1
            image = stage_1(
                prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
            ).images

            # Stage 2
            image = stage_2(
                image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
            ).images

            # Stage 3
            # Assumes stage 2's "pt" output lies in [-1, 1]; this maps it to
            # [0, 1] for the upscaler. TODO confirm against the installed
            # diffusers version.
            image = stage_3(prompt=full_prompt, image=(image.float() * 0.5 + 0.5), generator=generator, noise_level=100).images
            image[0].save(out_path)


if __name__ == "__main__":
    main()