# Copyright 2020 Google LLC All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import os
import re
import subprocess

import torch
from safetensors.torch import save_file

def _build_kohya_lora_command(args):
    """Return the argv list that launches ``sd-scripts/train_network.py``.

    The command is built as a list (not a shell string) so that paths and
    names containing spaces, quotes or ``$`` reach the training script
    unmodified instead of being re-interpreted by a shell.

    Args:
        args: Parsed command-line namespace; numeric fields are coerced with
            ``int``/``float`` so string values are accepted too.

    Returns:
        list[str]: ``accelerate launch ...`` argument vector.
    """
    output_dir = args.output_storage
    display_name = args.display_name

    command = [
        "accelerate",
        "launch",
        f"--num_cpu_threads_per_process={int(args.num_cpu_threads)}",
        "sd-scripts/train_network.py",
        "--enable_bucket",
        f"--pretrained_model_name_or_path={args.model_name}",
        f"--train_data_dir={args.input_storage}",
        f"--output_dir={output_dir}",
        f"--logging_dir={output_dir}/logs",
        f"--log_prefix={display_name}_logs",
        f"--resolution={args.resolution}",
        "--network_module=networks.lora",
        f"--max_train_epochs={int(args.max_train_epochs)}",
        f"--learning_rate={float(args.lr)}",
        f"--unet_lr={float(args.unet_lr)}",
        f"--text_encoder_lr={float(args.text_encoder_lr)}",
        f"--lr_scheduler={args.lr_scheduler}",
        "--lr_warmup_steps=0",
        "--lr_scheduler_num_cycles=1",
        f"--network_dim={int(args.network_dim)}",
        f"--network_alpha={int(args.network_alpha)}",
        f"--output_name={display_name}",
        f"--train_batch_size={int(args.batch_size)}",
        f"--save_every_n_epochs={int(args.save_every_n_epochs)}",
        "--mixed_precision=fp16",
        "--save_precision=fp16",
        "--seed=1337",
        "--cache_latents",
        "--clip_skip=2",
        "--prior_loss_weight=1",
        "--max_token_length=225",
        "--caption_extension=.txt",
        "--save_model_as=safetensors",
        "--min_bucket_reso=256",
        "--max_bucket_reso=1024",
        "--keep_tokens=0",
        "--xformers",
        "--shuffle_caption",
    ]

    # Optional flags: only emitted when the corresponding value is set.
    if args.network_weights:
        command.append(f"--network_weights={args.network_weights}")
    if args.reg_dir:
        command.append(f"--reg_data_dir={args.reg_dir}")
    if bool(args.use_8bit_adam):
        command.append("--use_8bit_adam")
    if bool(args.use_lion):
        command.append("--use_lion_optimizer")
    noise_offset = float(args.noise_offset)
    if noise_offset:
        command.append(f"--noise_offset={noise_offset}")
    # Truthiness (not ``is not None``) so an empty string is treated like the
    # other optional paths and does not produce an empty ``--in_json=``.
    if args.metadata_storage:
        command.append(f"--in_json={args.metadata_storage}")

    return command


def _copy_safetensors_to_nfs(output_dir, nfs_path="/mnt/nfs/model_repo"):
    """Copy every ``*.safetensors`` file in *output_dir* to ``<nfs_path>/model``.

    Prints "nfs not exist" and does nothing when *nfs_path* is missing.
    """
    if not os.path.exists(nfs_path):
        print("nfs not exist")
        return

    model_dir = f"{nfs_path}/model"
    if os.path.exists(model_dir):
        print(f"{nfs_path}/model already exists.")
    else:
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(model_dir, exist_ok=True)
        print(f"{nfs_path}/model has been created.")

    # Expand the pattern in Python so the directory name never passes
    # through a shell; escape() keeps glob metacharacters in it literal.
    pattern = os.path.join(glob.escape(output_dir), "*.safetensors")
    weight_files = sorted(glob.glob(pattern))
    if weight_files:
        subprocess.run(["cp", "--", *weight_files, model_dir])
    else:
        print(f"no .safetensors files found in {output_dir}")
    subprocess.run(["ls", model_dir])


def main(args):
    """Run LoRA training and optionally publish the result to the NFS share.

    Steps: validate ``args.method``, create and print the default accelerate
    config, change into ``/root/lora-scripts``, launch training, then copy the
    produced ``*.safetensors`` files to the NFS model folder when
    ``args.save_nfs`` is truthy.

    Args:
        args: Parsed command-line namespace (see the argument parser in this
            file for the expected attributes).

    Raises:
        ValueError: If ``args.method`` is not ``"kohya_lora"``.
        subprocess.CalledProcessError: If the training process exits with a
            non-zero status.
    """
    method = args.method
    # Fail fast: previously an unknown method crashed later with NameError
    # because no command had been built.
    if method != "kohya_lora":
        raise ValueError(
            f"unsupported training method: {method!r} (expected 'kohya_lora')"
        )

    # Best-effort setup/diagnostics; fixed strings, so shell=True is harmless.
    subprocess.run("accelerate config default", shell=True)
    subprocess.run("cat /root/.cache/huggingface/accelerate/default_config.yaml", shell=True)

    command = _build_kohya_lora_command(args)

    # The training script path in the command is relative to this directory.
    os.chdir("/root/lora-scripts")

    # start training; check=True so a failed run is not reported as success
    # and stale weights are not copied to NFS afterwards.
    subprocess.run(command, check=True)

    if bool(args.save_nfs):
        _copy_safetensors_to_nfs(args.output_storage)


def str_to_bool(value):
    """Convert a command-line string to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    ``--flag False`` evaluates to ``True``. This converter interprets the
    text instead.

    Args:
        value: A bool (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"1"``/``"0"``, ``"yes"``/``"no"``.
            The empty string maps to ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Create the command-line parser for the training launcher."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="kohya_lora", help="a tag")
    parser.add_argument("--num_cpu_threads", type=int, default=8, help="num of cpu threads per process")
    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5", help="bucket_name/model_folder")
    parser.add_argument("--input_storage", type=str, default="/root/dog_image_resize", help="/gcs/bucket_name/input_image_folder")
    parser.add_argument("--metadata_storage", type=str, default=None, help="metadata json path, for native training")
    parser.add_argument("--output_storage", type=str, default="/root/dog_output", help="/gcs/bucket_name/output_folder")
    parser.add_argument("--display_name", type=str, default="sks_dog", help="prompt")
    parser.add_argument("--resolution", type=str, default="512,512", help="resolution group")
    parser.add_argument("--max_train_epochs", type=int, default=10, help="max train epochs")
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--unet_lr", type=float, default=1e-4, help="unet learning rate")
    parser.add_argument("--text_encoder_lr", type=float, default=1e-5, help="text encoder learning rate")
    parser.add_argument("--lr_scheduler", type=str, default="cosine_with_restarts", help="")
    parser.add_argument("--network_dim", type=int, default=32, help="network dim 4~128")
    parser.add_argument("--network_alpha", type=int, default=32, help="often=network dim")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--save_every_n_epochs", type=int, default=2, help="save every n epochs")
    parser.add_argument("--network_weights", type=str, default="", help="lora model path,/gcs/bucket_name/lora_model")
    parser.add_argument("--reg_dir", type=str, default="", help="regularization data path")
    # Boolean options take an explicit value ("True"/"False"); str_to_bool
    # makes "False" actually mean False.
    parser.add_argument("--use_8bit_adam", type=str_to_bool, default=True, help="use 8bit adam optimizer")
    parser.add_argument("--use_lion", type=str_to_bool, default=False, help="lion optimizer")
    # float, not int: the suggested value 0.1 must be accepted.
    parser.add_argument("--noise_offset", type=float, default=0.0, help="0.1 if use")
    parser.add_argument("--save_nfs", type=str_to_bool, default=False, help="if save the model to file store")
    parser.add_argument("--save_nfs_only", type=str_to_bool, default=False, help="only copy file from gcs to filestore, no training")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    print(args)
    if args.save_nfs_only:
        # Copy-only mode: publish existing weights to the NFS share, no training.
        nfs_path = "/mnt/nfs/model_repo"
        if not os.path.exists(nfs_path):
            print("nfs not exist")
        else:
            model_dir = f"{nfs_path}/model"
            if os.path.exists(model_dir):
                print(f"{nfs_path}/model already exists.")
            else:
                # exist_ok guards against a concurrent creator.
                os.makedirs(model_dir, exist_ok=True)
                print(f"{nfs_path}/model has been created.")
            # Expand the pattern in Python so the path never goes through a shell.
            pattern = os.path.join(glob.escape(args.output_storage), "*.safetensors")
            weight_files = sorted(glob.glob(pattern))
            if weight_files:
                subprocess.run(["cp", "--", *weight_files, model_dir])
            else:
                print(f"no .safetensors files found in {args.output_storage}")
            subprocess.run(["ls", model_dir])
    else:
        main(args)
