def main()

in Stable-Diffusion-Vertex/Diffusers/train.py [0:0]


def main(args):
    """Train a Stable Diffusion model with the method selected by ``args.method``.

    Steps:
      1. Write a default ``accelerate`` config and print it.
      2. Build the command line for the selected training script and run it
         from that script's example directory (this changes the process cwd,
         so relative paths in ``args`` resolve against that directory).
      3. Convert the trained weights to ``.safetensors`` (LoRA methods via
         ``bin_to_safetensors``; full-model diffusers methods via the diffusers
         conversion script, written as ``dreambooth.safetensors``). ``peft_lora``
         output is not converted here.
      4. If ``args.save_nfs`` is truthy, copy ``*.safetensors`` from the output
         directory to ``<args.nfs_mnt_dir>/<args.method>``.

    Supported methods: ``diffuser_dreambooth``, ``diffuser_dreambooth_lora``,
    ``diffuser_text_to_image``, ``diffuser_text_to_image_lora``, ``peft_lora``.

    Args:
        args: Parsed command-line namespace. Numeric fields are converted with
            ``int``/``float``; boolean flags may be bools or strings such as
            ``"False"``/``"0"``.

    Raises:
        ValueError: if ``args.method`` is unsupported or a numeric argument
            cannot be parsed.
        subprocess.CalledProcessError: if the training script exits non-zero.
    """
    # Write a default accelerate config and print it.
    subprocess.run(["accelerate", "config", "default"])
    subprocess.run(["cat", "/root/.cache/huggingface/accelerate/default_config.yaml"])

    method = args.method
    output_dir = args.output_storage

    # Note the constraint (from the training scripts): an error is raised when
    # train_text_encoder is set, gradient_accumulation_steps > 1 and
    # accelerator.num_processes > 1.
    workdir, train_cmd = _build_train_command(args)
    os.chdir(workdir)

    if method == "peft_lora":
        # The peft example expects the output and class-image dirs to exist.
        _ensure_dir(output_dir)
        _ensure_dir(output_dir + "/class_data")

    # Start training. Arguments are passed as a list (no shell), so prompts and
    # paths reach the script verbatim. Fail the job on a training error instead
    # of converting/copying missing weights.
    subprocess.run(train_cmd, check=True)

    # Convert to safetensors.
    if method in ("diffuser_dreambooth_lora", "diffuser_text_to_image_lora"):
        bin_to_safetensors(args.output_storage)
    elif method in ("diffuser_dreambooth", "diffuser_text_to_image"):
        subprocess.run([
            "python3",
            "/root/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py",
            "--model_path", output_dir,
            "--checkpoint_path", f"{output_dir}/dreambooth.safetensors",
            "--use_safetensors",
        ])

    if _as_bool(args.save_nfs):
        _copy_to_nfs(output_dir, args.nfs_mnt_dir, method)


def _as_bool(value):
    """Interpret a command-line flag value as a boolean.

    ``bool("False")`` is ``True``, so a flag that arrives as a string would
    otherwise always count as enabled. Strings spelling a common "false" value
    map to ``False``; every other value keeps plain Python truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in {"", "0", "false", "no", "n", "off", "none"}
    return bool(value)


def _ensure_dir(path):
    """Create ``path`` (including missing parents) if absent and print which case applied."""
    if os.path.exists(path):
        print(f"{path} already exists.")
    else:
        os.makedirs(path)
        print(f"{path} has been created.")


def _build_train_command(args):
    """Return ``(working_dir, argv)`` for the training run selected by ``args.method``.

    Each option is a separate argv element, so values containing quotes,
    spaces or ``$`` are passed verbatim rather than re-parsed by a shell.

    Raises:
        ValueError: if ``args.method`` is not a supported training method.
    """
    method = args.method
    model_name = args.model_name  # e.g. "runwayml/stable-diffusion-v1-5"
    instance_dir = args.input_storage
    output_dir = args.output_storage
    prompt = args.prompt
    class_prompt = args.class_prompt
    num_class_images = int(args.num_class_images)
    steps = int(args.max_train_steps)
    train_text_encoder = _as_bool(args.text_encoder)
    set_grads_to_none = _as_bool(args.set_grads_to_none)
    resolution = int(args.resolution)
    batch_size = int(args.batch_size)
    use_8bit = _as_bool(args.use_8bit)
    lr = float(args.lr)
    grad_accum_steps = int(args.gradient_accumulation_steps)
    num_valid_images = int(args.num_validation_images)
    valid_prompt = args.validation_prompt
    class_data_dir = f"{output_dir}/class_data"

    if method == "diffuser_dreambooth":
        workdir = "/root/diffusers/examples/dreambooth"
        cmd = [
            "accelerate", "launch", "train_dreambooth.py",
            f"--pretrained_model_name_or_path={model_name}",
            f"--instance_data_dir={instance_dir}",
            f"--output_dir={output_dir}",
            f"--instance_prompt={prompt}",
            f"--class_data_dir={class_data_dir}",
            "--with_prior_preservation",
            "--prior_loss_weight=1.0",
            f"--class_prompt={class_prompt}",
            f"--resolution={resolution}",
            f"--train_batch_size={batch_size}",
            "--gradient_checkpointing",
            f"--gradient_accumulation_steps={grad_accum_steps}",
            "--mixed_precision=fp16",
            f"--learning_rate={lr}",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            f"--num_class_images={num_class_images}",
            "--enable_xformers_memory_efficient_attention",
            f"--max_train_steps={steps}",
        ]
        if train_text_encoder:
            cmd.append("--train_text_encoder")
        if set_grads_to_none:
            cmd.append("--set_grads_to_none")
        if use_8bit:
            cmd.append("--use_8bit_adam")

    elif method == "diffuser_dreambooth_lora":
        workdir = "/root/diffusers/examples/dreambooth"
        cmd = [
            "accelerate", "launch", "train_dreambooth_lora.py",
            f"--pretrained_model_name_or_path={model_name}",
            f"--instance_data_dir={instance_dir}",
            f"--output_dir={output_dir}",
            f"--instance_prompt={prompt}",
            f"--resolution={resolution}",
            f"--train_batch_size={batch_size}",
            "--mixed_precision=fp16",
            f"--gradient_accumulation_steps={grad_accum_steps}",
            f"--learning_rate={lr}",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            f"--max_train_steps={steps}",
        ]
        if use_8bit:
            cmd.append("--use_8bit_adam")

    elif method == "diffuser_text_to_image":
        workdir = "/root/diffusers/examples/text_to_image"
        cmd = [
            "accelerate", "launch", "--mixed_precision=fp16", "train_text_to_image.py",
            f"--pretrained_model_name_or_path={model_name}",
            f"--train_data_dir={instance_dir}",
            "--use_ema",
            f"--resolution={resolution}", "--center_crop", "--random_flip",
            "--mixed_precision=fp16",
            f"--train_batch_size={batch_size}",
            f"--gradient_accumulation_steps={grad_accum_steps}",
            "--gradient_checkpointing",
            f"--max_train_steps={steps}",
            f"--learning_rate={lr}",
            "--max_grad_norm=1",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            "--enable_xformers_memory_efficient_attention",
            f"--output_dir={output_dir}",
        ]
        if use_8bit:
            cmd.append("--use_8bit_adam")

    elif method == "diffuser_text_to_image_lora":
        workdir = "/root/diffusers/examples/text_to_image"
        cmd = [
            "accelerate", "launch", "--mixed_precision=fp16", "train_text_to_image_lora.py",
            f"--pretrained_model_name_or_path={model_name}",
            f"--train_data_dir={instance_dir}",
            f"--resolution={resolution}", "--center_crop", "--random_flip",
            "--mixed_precision=fp16",
            f"--train_batch_size={batch_size}",
            f"--gradient_accumulation_steps={grad_accum_steps}",
            "--gradient_checkpointing",
            f"--max_train_steps={steps}",
            f"--learning_rate={lr}",
            "--max_grad_norm=1",
            "--lr_scheduler=constant", "--lr_warmup_steps=0",
            "--seed=42",
            f"--num_validation_images={num_valid_images}",
            f"--validation_prompt={valid_prompt}",
            f"--output_dir={output_dir}",
        ]

    elif method == "peft_lora":
        workdir = "/root/peft/examples/lora_dreambooth"
        # NOTE(review): resolution, batch size, 8-bit Adam, gradient
        # accumulation and learning rate are fixed here and ignore the
        # corresponding args; kept as-is — confirm this is intended.
        cmd = [
            "accelerate", "launch", "train_dreambooth.py",
            f"--pretrained_model_name_or_path={model_name}",
            f"--instance_data_dir={instance_dir}",
            f"--output_dir={output_dir}",
            "--with_prior_preservation",
            "--prior_loss_weight=1",
            f"--num_class_images={num_class_images}",
            f"--class_prompt={class_prompt}",
            f"--class_data_dir={class_data_dir}",
            f"--instance_prompt={prompt}",
            "--use_lora",
            "--lora_r=4",
            "--lora_alpha=4",
            "--lora_bias=none",
            "--lora_dropout=0.0",
            "--lora_text_encoder_r=4",
            "--lora_text_encoder_alpha=4",
            "--lora_text_encoder_bias=none",
            "--lora_text_encoder_dropout=0.0",
            "--gradient_checkpointing",
            "--resolution=512",
            "--train_batch_size=1",
            "--use_8bit_adam",
            "--mixed_precision=fp16",
            "--gradient_accumulation_steps=1",
            "--learning_rate=1e-4",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            "--enable_xformers_memory_efficient_attention",
            f"--max_train_steps={steps}",
        ]
        if train_text_encoder:
            cmd.append("--train_text_encoder")

    else:
        raise ValueError(
            f"Unsupported method {method!r}; expected one of: diffuser_dreambooth, "
            "diffuser_dreambooth_lora, diffuser_text_to_image, "
            "diffuser_text_to_image_lora, peft_lora."
        )

    return workdir, cmd


def _copy_to_nfs(output_dir, nfs_path, method):
    """Copy every ``*.safetensors`` file in ``output_dir`` into ``<nfs_path>/<method>``.

    Prints a message and does nothing if ``nfs_path`` does not exist. Lists
    the destination directory afterwards.
    """
    import glob
    import shutil

    if not os.path.exists(nfs_path):
        print("nfs not exist")
        return

    dest = nfs_path + '/' + method
    _ensure_dir(dest)

    # glob.escape keeps special characters in output_dir from acting as wildcards.
    weights = sorted(glob.glob(os.path.join(glob.escape(output_dir), "*.safetensors")))
    if not weights:
        print(f"no .safetensors files found in {output_dir}")
    for weight_path in weights:
        shutil.copy(weight_path, dest)
    subprocess.run(["ls", dest])