def main()

in Stable-Diffusion-Vertex/Kohya-lora/train_kohya.py [0:0]


def main(args):
    """Run a Kohya LoRA training job and optionally copy the results to NFS.

    The function writes a default ``accelerate`` config, builds an
    ``accelerate launch ... sd-scripts/train_network.py`` command line from
    ``args``, runs it through the shell from ``/root/lora-scripts``, and, when
    ``args.save_nfs`` is truthy, copies the produced ``*.safetensors`` files
    into ``<args.nfs_mnt_dir>/kohya``.

    Args:
        args: Object exposing the attributes read below: ``method``,
            ``num_cpu_threads``, ``model_name``, ``input_storage``,
            ``metadata_storage``, ``output_storage``, ``display_name``,
            ``resolution``, ``max_train_epochs``, ``lr``, ``unet_lr``,
            ``text_encoder_lr``, ``lr_scheduler``, ``network_dim``,
            ``network_alpha``, ``batch_size``, ``save_every_n_epochs``,
            ``network_weights``, ``reg_dir``, ``use_8bit_adam``, ``use_lion``,
            ``noise_offset``, ``save_nfs`` and ``nfs_mnt_dir``.

    Raises:
        ValueError: If ``args.method`` is not ``"kohya_lora"``. Previously this
            surfaced as an ``UnboundLocalError`` once the (never built)
            command string was executed.
    """
    METHOD = args.method
    # Fail fast, before any side effect: "kohya_lora" is the only method this
    # function knows how to build a command for.
    if METHOD != "kohya_lora":
        raise ValueError(
            f"Unsupported training method {METHOD!r}; expected 'kohya_lora'."
        )

    # Create the default accelerate config and echo it for the job logs.
    subprocess.run("accelerate config default", shell=True)
    subprocess.run("cat /root/.cache/huggingface/accelerate/default_config.yaml", shell=True)

    NUM_CPU_THREADS = int(args.num_cpu_threads)
    MODEL_NAME = args.model_name  # e.g. "runwayml/stable-diffusion-v1-5"
    INSTANCE_DIR = args.input_storage
    METADATA_DIR = args.metadata_storage
    OUTPUT_DIR = args.output_storage
    DISPLAY_NAME = args.display_name
    RESOLUTION = args.resolution
    MAX_EPOCHS = int(args.max_train_epochs)
    LR = float(args.lr)
    UNET_LR = float(args.unet_lr)
    TEXT_ENCODER_LR = float(args.text_encoder_lr)
    LR_SCHEDULER = args.lr_scheduler
    NETWORK_DIM = int(args.network_dim)
    NETWORK_ALPHA = int(args.network_alpha)
    BATCH_SIZE = int(args.batch_size)
    SAVE_N_EPOCHS = int(args.save_every_n_epochs)
    NETWORK_WEIGHTS = args.network_weights
    REG_DIR = args.reg_dir
    # NOTE(review): bool() of any non-empty string (including "False") is
    # True -- confirm the argument parser yields real booleans for these flags.
    USE_8BIT_ADAM = bool(args.use_8bit_adam)
    USE_LION = bool(args.use_lion)
    NOISE_OFFSET = float(args.noise_offset)

    # The relative script path below is resolved from this directory.
    os.chdir("/root/lora-scripts")
    # for complex commands, with many args, use string + `shell=True`:
    cmd_str = (f'accelerate launch --num_cpu_threads_per_process={NUM_CPU_THREADS} sd-scripts/train_network.py '
               f'--enable_bucket '
               f'--pretrained_model_name_or_path="{MODEL_NAME}" '
               f'--train_data_dir="{INSTANCE_DIR}" '
               f'--output_dir="{OUTPUT_DIR}" '
               f'--logging_dir="{OUTPUT_DIR}/logs" '
               f'--log_prefix="{DISPLAY_NAME}_logs" '
               f'--resolution="{RESOLUTION}" '
               f'--network_module="networks.lora" '
               f'--max_train_epochs={MAX_EPOCHS} '
               f'--learning_rate={LR} '
               f'--unet_lr={UNET_LR} '
               f'--text_encoder_lr={TEXT_ENCODER_LR} '
               f'--lr_scheduler="{LR_SCHEDULER}" '
               f'--lr_warmup_steps=0 '
               f'--lr_scheduler_num_cycles=1 '
               f'--network_dim={NETWORK_DIM} '
               f'--network_alpha={NETWORK_ALPHA} '
               f'--output_name="{DISPLAY_NAME}" '
               f'--train_batch_size={BATCH_SIZE} '
               f'--save_every_n_epochs={SAVE_N_EPOCHS} '
               f'--mixed_precision="fp16" '
               f'--save_precision="fp16" '
               f'--seed="1337" '
               f'--cache_latents '
               f'--clip_skip=2 '
               f'--prior_loss_weight=1 '
               f'--max_token_length=225 '
               f'--caption_extension=".txt" '
               f'--save_model_as="safetensors" '
               f'--min_bucket_reso=256 '
               f'--max_bucket_reso=1024 '
               f'--keep_tokens=0 '
               f'--xformers --shuffle_caption ')

    # Optional flags are appended only when the corresponding value is set.
    if NETWORK_WEIGHTS:
        cmd_str += f' --network_weights="{NETWORK_WEIGHTS}"'
    if REG_DIR:
        cmd_str += f' --reg_data_dir="{REG_DIR}"'
    if USE_8BIT_ADAM:
        cmd_str += ' --use_8bit_adam'
    if USE_LION:
        cmd_str += ' --use_lion_optimizer'
    if NOISE_OFFSET:
        cmd_str += f' --noise_offset={NOISE_OFFSET}'
    if METADATA_DIR is not None:
        cmd_str += f' --in_json="{METADATA_DIR}"'

    # start training
    training = subprocess.run(cmd_str, shell=True)
    if training.returncode != 0:
        # Report the failure, but still attempt the NFS copy below so that any
        # checkpoints written before the failure are not lost.
        print(f"training command exited with code {training.returncode}")

    if not bool(args.save_nfs):
        return

    nfs_path = args.nfs_mnt_dir
    if not os.path.exists(nfs_path):
        print("nfs not exist")
        return

    kohya_dir = nfs_path + '/kohya'
    if not os.path.exists(kohya_dir):
        os.mkdir(kohya_dir)
        print(f"{nfs_path}/kohya has been created.")
    else:
        print(f"{nfs_path}/kohya already exists.")

    # Quote the directories (as the training command does) so paths with
    # spaces survive the shell; the glob stays outside the quotes to expand.
    copy_cmd = f'cp "{OUTPUT_DIR}"/*.safetensors "{kohya_dir}"'
    subprocess.run(copy_cmd, shell=True)
    subprocess.run(f'ls "{kohya_dir}"', shell=True)