def parse_args()

in phi3/src_train/train_mlflow.py [0:0]


def _str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"False"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean options need an
    explicit converter.

    Args:
        value: The raw option value (a string), or an already-converted bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse command-line options for the training script.

    Covers model/data locations, optimisation hyperparameters, LoRA
    hyperparameters and Weights & Biases settings.

    Args:
        argv: Sequence of argument strings to parse. When ``None`` (the
            default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()

    # hyperparameters
    parser.add_argument("--model_name_or_path", default="microsoft/Phi-3.5-mini-instruct", type=str, help="model identifier or local path of the base model")
    parser.add_argument("--train_dir", default="data", type=str, help="Input directory for training")
    parser.add_argument("--model_dir", default="./model", type=str, help="output directory for model")
    parser.add_argument("--epochs", default=1, type=int, help="number of epochs")
    parser.add_argument("--train_batch_size", default=8, type=int, help="training - mini batch size for each gpu/process")
    parser.add_argument("--eval_batch_size", default=8, type=int, help="evaluation - mini batch size for each gpu/process")
    parser.add_argument("--learning_rate", default=5e-06, type=float, help="learning rate")
    parser.add_argument("--logging_steps", default=2, type=int, help="logging steps")
    parser.add_argument("--save_steps", default=100, type=int, help="save steps")
    parser.add_argument("--grad_accum_steps", default=4, type=int, help="gradient accumulation steps")
    parser.add_argument("--lr_scheduler_type", default="linear", type=str, help="learning rate scheduler type")
    parser.add_argument("--seed", default=0, type=int, help="seed")
    parser.add_argument("--warmup_ratio", default=0.2, type=float, help="warmup ratio")
    parser.add_argument("--max_seq_length", default=2048, type=int, help="max seq length")
    # type=bool would treat "--save_merged_model False" as True; parse explicitly.
    # A bare "--save_merged_model" (no value) enables it via const=True.
    parser.add_argument(
        "--save_merged_model",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="whether to save the merged model (true/false)",
    )

    # lora hyperparameters
    parser.add_argument("--lora_r", default=16, type=int, help="lora r")
    parser.add_argument("--lora_alpha", default=16, type=int, help="lora alpha")
    parser.add_argument("--lora_dropout", default=0.05, type=float, help="lora dropout")

    # wandb params
    parser.add_argument("--wandb_api_key", type=str, default="", help="wandb API key")
    parser.add_argument("--wandb_project", type=str, default="", help="wandb project name")
    parser.add_argument("--wandb_run_name", type=str, default="", help="wandb run name")
    parser.add_argument("--wandb_watch", type=str, default="gradients")  # options: false | gradients | all
    parser.add_argument("--wandb_log_model", type=str, default="false")  # options: false | true

    return parser.parse_args(argv)