# Stable-Diffusion-Vertex/hpo/kohya-lora/train_kohya.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vertex AI entry point that launches kohya-ss LoRA training via `accelerate`.

Builds an ``accelerate launch ... sd-scripts/train_network.py`` command line
from CLI arguments, runs it through the shell, and optionally copies the
resulting ``*.safetensors`` models onto a mounted Filestore (NFS) share.

NOTE(review): the previous revision imported ``re``, ``torch`` and
``safetensors.torch.save_file`` without ever using them; those dead imports
have been removed.
"""
import argparse
import os
import subprocess


def str2bool(value):
    """Parse a boolean-ish CLI string into a real bool.

    ``argparse`` with ``type=bool`` is a classic trap: ``bool("False")`` is
    ``True`` because any non-empty string is truthy, so ``--use_lion False``
    used to silently ENABLE the flag. This converter accepts the usual
    spellings of "true" and treats everything else as False.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "t", "yes", "y")


def copy_models_to_nfs(output_dir, nfs_path):
    """Copy trained ``*.safetensors`` from *output_dir* into ``<nfs_path>/kohya``.

    Best-effort, matching the original script: if the NFS mount is missing we
    only log and return instead of raising.
    """
    if not os.path.exists(nfs_path):
        print("nfs not exist")
        return
    kohya_dir = f"{nfs_path}/kohya"
    if not os.path.exists(kohya_dir):
        os.mkdir(kohya_dir)
        print(f"{kohya_dir} has been created.")
    else:
        print(f"{kohya_dir} already exists.")
    # shell=True is deliberate here: inputs come from the job operator (not
    # untrusted users) and we rely on the shell glob for *.safetensors.
    subprocess.run(f'cp {output_dir}/*.safetensors {kohya_dir}', shell=True)
    subprocess.run(f'ls {kohya_dir}', shell=True)


def main(args):
    """Configure accelerate, launch kohya LoRA training, optionally export to NFS."""
    subprocess.run("accelerate config default", shell=True)
    subprocess.run("cat /root/.cache/huggingface/accelerate/default_config.yaml", shell=True)

    METHOD = args.method
    NUM_CPU_THREADS = int(args.num_cpu_threads)
    MODEL_NAME = args.model_name  # e.g. "runwayml/stable-diffusion-v1-5"
    INSTANCE_DIR = args.input_storage
    METADATA_DIR = args.metadata_storage
    OUTPUT_DIR = args.output_storage
    DISPLAY_NAME = args.display_name
    RESOLUTION = args.resolution
    MAX_EPOCHS = int(args.max_train_epochs)
    LR = float(args.lr)
    UNET_LR = float(args.unet_lr)
    TEXT_ENCODER_LR = float(args.text_encoder_lr)
    LR_SCHEDULER = args.lr_scheduler
    NETWORK_DIM = int(args.network_dim)
    NETWORK_ALPHA = int(args.network_alpha)
    BATCH_SIZE = int(args.batch_size)
    SAVE_N_EPOCHS = int(args.save_every_n_epochs)
    NETWORK_WEIGHTS = args.network_weights
    REG_DIR = args.reg_dir
    USE_8BIT_ADAM = str2bool(args.use_8bit_adam)
    USE_LION = str2bool(args.use_lion)
    NOISE_OFFSET = float(args.noise_offset)
    HPO = args.hpo

    if METHOD == "kohya_lora":
        os.chdir("/root/lora-scripts")
        # For a complex command with many args, build one string + `shell=True`.
        cmd_str = (f'accelerate launch --num_cpu_threads_per_process={NUM_CPU_THREADS} sd-scripts/train_network.py '
                   f'--enable_bucket '
                   f'--pretrained_model_name_or_path="{MODEL_NAME}" '
                   f'--train_data_dir="{INSTANCE_DIR}" '
                   f'--output_dir="{OUTPUT_DIR}" '
                   f'--logging_dir="{OUTPUT_DIR}/logs" '
                   f'--log_prefix="{DISPLAY_NAME}_logs" '
                   f'--resolution="{RESOLUTION}" '
                   f'--network_module="networks.lora" '
                   f'--max_train_epochs={MAX_EPOCHS} '
                   f'--learning_rate={LR} '
                   f'--unet_lr={UNET_LR} '
                   f'--text_encoder_lr={TEXT_ENCODER_LR} '
                   f'--lr_scheduler="{LR_SCHEDULER}" '
                   f'--lr_warmup_steps=0 '
                   f'--lr_scheduler_num_cycles=1 '
                   f'--network_dim={NETWORK_DIM} '
                   f'--network_alpha={NETWORK_ALPHA} '
                   f'--output_name="{DISPLAY_NAME}" '
                   f'--train_batch_size={BATCH_SIZE} '
                   f'--save_every_n_epochs={SAVE_N_EPOCHS} '
                   f'--mixed_precision="fp16" '
                   f'--save_precision="fp16" '
                   f'--seed="1337" '
                   f'--cache_latents '
                   f'--clip_skip=2 '
                   f'--prior_loss_weight=1 '
                   f'--max_token_length=225 '
                   f'--caption_extension=".txt" '
                   f'--save_model_as="safetensors" '
                   f'--min_bucket_reso=256 '
                   f'--max_bucket_reso=1024 '
                   f'--keep_tokens=0 '
                   f'--xformers --shuffle_caption '
                   f'--hpo="{HPO}"')

        # Optional flags, appended only when the corresponding value is set.
        if NETWORK_WEIGHTS:
            cmd_str += f' --network_weights="{NETWORK_WEIGHTS}"'
        if REG_DIR:
            cmd_str += f' --reg_data_dir="{REG_DIR}"'
        if USE_8BIT_ADAM:
            cmd_str += ' --use_8bit_adam'
        if USE_LION:
            cmd_str += ' --use_lion_optimizer'
        if NOISE_OFFSET:
            cmd_str += f' --noise_offset={NOISE_OFFSET}'
        if METADATA_DIR is not None:
            cmd_str += f' --in_json="{METADATA_DIR}"'

        # Start training.
        subprocess.run(cmd_str, shell=True)

        if str2bool(args.save_nfs):
            copy_models_to_nfs(OUTPUT_DIR, args.nfs_mnt_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="kohya_lora", help="a tag")
    parser.add_argument("--num_cpu_threads", type=int, default=8,
                        help="num of cpu threads per process")
    parser.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5",
                        help="bucket_name/model_folder")
    parser.add_argument("--input_storage", type=str, default="/root/dog_image_resize",
                        help="/gcs/bucket_name/input_image_folder")
    parser.add_argument("--metadata_storage", type=str, default=None,
                        help="metadata json path, for native training")
    parser.add_argument("--output_storage", type=str, default="/root/dog_output",
                        help="/gcs/bucket_name/output_folder")
    parser.add_argument("--display_name", type=str, default="sks_dog", help="prompt")
    parser.add_argument("--resolution", type=str, default="512,512", help="resolution group")
    parser.add_argument("--max_train_epochs", type=int, default=10, help="max train epochs")
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--unet_lr", type=float, default=1e-4, help="unet learning rate")
    parser.add_argument("--text_encoder_lr", type=float, default=1e-5,
                        help="text encoder learning rate")
    parser.add_argument("--lr_scheduler", type=str, default="cosine_with_restarts", help="")
    parser.add_argument("--network_dim", type=int, default=32, help="network dim 4~128")
    parser.add_argument("--network_alpha", type=int, default=32, help="often=network dim")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--save_every_n_epochs", type=int, default=2, help="save every n epochs")
    parser.add_argument("--network_weights", type=str, default="",
                        help="lora model path,/gcs/bucket_name/lora_model")
    parser.add_argument("--reg_dir", type=str, default="", help="regularization data path")
    # str2bool (not type=bool): with type=bool every non-empty string,
    # including "False", parsed as True.
    parser.add_argument("--use_8bit_adam", type=str2bool, default=True,
                        help="use 8bit adam optimizer")
    parser.add_argument("--use_lion", type=str2bool, default=False, help="lion optimizer")
    # float (not int): the documented value 0.1 would make argparse reject
    # the argument with type=int.
    parser.add_argument("--noise_offset", type=float, default=0, help="0.1 if use")
    parser.add_argument("--save_nfs", type=str2bool, default=False,
                        help="if save the model to file store")
    parser.add_argument("--save_nfs_only", type=str2bool, default=False,
                        help="only copy file from gcs to filestore, no training")
    parser.add_argument("--nfs_mnt_dir", type=str, default="/mnt/nfs/model_repo",
                        help="Filestore's mount directory")
    parser.add_argument("--hpo", type=str, default="n", help="if using hyper parameter tuning")

    args = parser.parse_args()
    print(args)

    if args.save_nfs_only:
        # Export-only mode: no training, just copy existing models to NFS.
        copy_models_to_nfs(args.output_storage, args.nfs_mnt_dir)
    else:
        main(args)