def parse_args()

in 2_slm-fine-tuning-mlstudio/phi/src_train/train.py


import argparse


def parse_args():
    # set up the argument parser
    parser = argparse.ArgumentParser()

    # hyperparameters
    parser.add_argument(
        "--model_name_or_path",
        default="microsoft/Phi-4-mini-instruct",
        type=str,
        help="Hugging Face model ID or local path of the base model",
    )
    parser.add_argument(
        "--train_dir",
        default="data",
        type=str,
        help="input directory containing the training data",
    )
    parser.add_argument(
        "--model_dir",
        default="./model",
        type=str,
        help="output directory for the trained model",
    )
    parser.add_argument("--epochs", default=1, type=int, help="number of epochs")
    parser.add_argument(
        "--output_dir",
        default="./output_dir",
        type=str,
        help="directory for intermediate checkpoints while training",
    )
    parser.add_argument(
        "--train_batch_size",
        default=2,
        type=int,
        help="training - mini batch size for each gpu/process",
    )
    parser.add_argument(
        "--eval_batch_size",
        default=4,
        type=int,
        help="evaluation - mini batch size for each gpu/process",
    )
    parser.add_argument(
        "--learning_rate", default=5e-06, type=float, help="learning rate"
    )
    parser.add_argument("--logging_steps", default=2, type=int, help="logging steps")
    parser.add_argument("--save_steps", default=100, type=int, help="save steps")
    parser.add_argument(
        "--grad_accum_steps", default=4, type=int, help="gradient accumulation steps"
    )
    parser.add_argument("--lr_scheduler_type", default="linear", type=str)
    parser.add_argument("--seed", default=0, type=int, help="seed")
    parser.add_argument("--warmup_ratio", default=0.2, type=float, help="warmup ratio")
    parser.add_argument(
        "--max_seq_length", default=2048, type=int, help="max seq length"
    )
    parser.add_argument("--save_merged_model", type=bool, default=False)

    # lora hyperparameters
    parser.add_argument("--lora_r", default=16, type=int, help="lora r")
    parser.add_argument("--lora_alpha", default=16, type=int, help="lora alpha")
    parser.add_argument("--lora_dropout", default=0.05, type=float, help="lora dropout")

    # wandb params
    parser.add_argument("--wandb_api_key", type=str, default="")
    parser.add_argument("--wandb_project", type=str, default="")
    parser.add_argument("--wandb_run_name", type=str, default="")
    parser.add_argument(
        "--wandb_watch",
        type=str,
        default="gradients",
        help="wandb watch mode: false | gradients | all",
    )
    parser.add_argument(
        "--wandb_log_model",
        type=str,
        default="false",
        help="log the model to wandb: false | true",
    )

    return parser.parse_args()
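

Below is a minimal usage sketch. It assumes the script is launched as "python train.py ..." and that the LoRA flags feed a peft LoraConfig; that mapping is inferred from the argument names and is not shown in this snippet.


from peft import LoraConfig

if __name__ == "__main__":
    # e.g. python train.py --epochs 3 --learning_rate 2e-5 --lora_r 8
    args = parse_args()

    # hypothetical wiring: map the LoRA flags onto a peft LoraConfig
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        task_type="CAUSAL_LM",  # assumption: causal-LM fine-tuning
    )
    print(f"fine-tuning {args.model_name_or_path} for {args.epochs} epoch(s)")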