in 2_slm-fine-tuning-mlstudio/phi/src_train/train_mlflow.py [0:0]
def parse_args(argv=None):
    """Parse command-line arguments for Phi fine-tuning.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` (backward-compatible
            with the zero-argument call).

    Returns:
        argparse.Namespace holding training, LoRA, and Weights & Biases
        hyperparameters.
    """

    def _str2bool(value):
        # argparse's type=bool is a trap: bool("False") is True because any
        # non-empty string is truthy. Parse common true/false spellings so
        # "--save_merged_model False" actually disables merging.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()

    # --- core training hyperparameters ---
    parser.add_argument(
        "--model_name_or_path",
        default="microsoft/Phi-3.5-mini-instruct",
        type=str,
        help="Hugging Face model id or local path of the base model",
    )
    parser.add_argument(
        "--train_dir", default="data", type=str, help="Input directory for training"
    )
    parser.add_argument(
        "--model_dir", default="./model", type=str, help="output directory for model"
    )
    parser.add_argument("--epochs", default=1, type=int, help="number of epochs")
    parser.add_argument(
        "--train_batch_size",
        default=8,
        type=int,
        help="training - mini batch size for each gpu/process",
    )
    parser.add_argument(
        "--eval_batch_size",
        default=8,
        type=int,
        help="evaluation - mini batch size for each gpu/process",
    )
    parser.add_argument(
        "--learning_rate", default=5e-06, type=float, help="learning rate"
    )
    parser.add_argument("--logging_steps", default=2, type=int, help="logging steps")
    parser.add_argument("--save_steps", default=100, type=int, help="save steps")
    parser.add_argument(
        "--grad_accum_steps", default=4, type=int, help="gradient accumulation steps"
    )
    parser.add_argument("--lr_scheduler_type", default="linear", type=str)
    parser.add_argument("--seed", default=0, type=int, help="seed")
    parser.add_argument("--warmup_ratio", default=0.2, type=float, help="warmup ratio")
    parser.add_argument(
        "--max_seq_length", default=2048, type=int, help="max seq length"
    )
    # NOTE: type=_str2bool, not type=bool — see helper above for why.
    parser.add_argument("--save_merged_model", type=_str2bool, default=False)

    # --- LoRA hyperparameters ---
    parser.add_argument("--lora_r", default=16, type=int, help="lora r")
    parser.add_argument("--lora_alpha", default=16, type=int, help="lora alpha")
    parser.add_argument("--lora_dropout", default=0.05, type=float, help="lora dropout")

    # --- Weights & Biases settings ---
    parser.add_argument("--wandb_api_key", type=str, default="")
    parser.add_argument("--wandb_project", type=str, default="")
    parser.add_argument("--wandb_run_name", type=str, default="")
    parser.add_argument(
        "--wandb_watch", type=str, default="gradients"
    )  # options: false | gradients | all
    parser.add_argument(
        "--wandb_log_model", type=str, default="false"
    )  # options: false | true

    return parser.parse_args(argv)