def get_arg_parser()

in distilvit/train.py [0:0]


def get_arg_parser(root_dir=None) -> argparse.ArgumentParser:
    """Build the command-line parser for training a Vision Encoder Decoder model.

    Args:
        root_dir: Base directory used to derive the default ``--save-dir``,
            ``--cache-dir`` (``<root_dir>/cache``) and ``--checkpoints-dir``
            (``<root_dir>/checkpoints``). When ``None``, the parent of the
            directory containing this file is used.

    Returns:
        An ``argparse.ArgumentParser`` whose help output also shows each
        option's default value (``ArgumentDefaultsHelpFormatter``).

    Note:
        The ``--device`` default is computed by calling ``get_device()`` at
        parser-construction time, not at parse time.
    """
    if root_dir is None:
        root_dir = os.path.join(os.path.dirname(__file__), "..")

    parser = argparse.ArgumentParser(
        description="Train a Vision Encoder Decoder Model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--model-id",
        default=MODEL_ID,
        type=str,
        help="Model ID",
    )

    parser.add_argument(
        "--sample",
        default=None,
        type=int,
        help="Sample data",
    )

    parser.add_argument(
        "--tag",
        type=str,
        help="HF tag",
        default=None,
    )

    # Output/cache locations all default to paths under root_dir.
    parser.add_argument(
        "--save-dir",
        default=root_dir,
        type=str,
        help="Save dir",
    )

    parser.add_argument(
        "--cache-dir",
        default=os.path.join(root_dir, "cache"),
        type=str,
        help="Cache dir",
    )

    parser.add_argument(
        "--prune-cache",
        default=False,
        action="store_true",
        help="Empty cache dir",
    )

    parser.add_argument(
        "--checkpoints-dir",
        default=os.path.join(root_dir, "checkpoints"),
        type=str,
        help="Checkpoints dir",
    )

    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="Debug mode",
    )

    parser.add_argument(
        "--num-train-epochs", type=int, default=3, help="Number of epochs"
    )

    parser.add_argument("--eval-steps", type=int, default=100, help="Evaluation steps")
    parser.add_argument("--save-steps", type=int, default=100, help="Save steps")

    # NOTE(review): the encoder default (vit-base-patch16-224) and the
    # feature-extractor default (vit-base-patch16-224-in21k) are different
    # checkpoints; confirm this pairing is intentional.
    parser.add_argument(
        "--encoder-model",
        default="google/vit-base-patch16-224",
        type=str,
        help="Base model for the encoder",
    )
    parser.add_argument(
        "--base-model",
        default=None,
        type=str,
        help="Base model to train again from",
    )
    parser.add_argument(
        "--device",
        # Evaluated once, when the parser is built.
        default=get_device(),
        type=str,
        choices=["cpu", "cuda", "mps"],
        help="Device to use for training",
    )

    parser.add_argument(
        "--base-model-revision",
        default=None,
        type=str,
        help="Base model revision",
    )

    parser.add_argument("--push-to-hub", action="store_true", help="Push to hub")

    parser.add_argument(
        "--feature-extractor-model",
        default="google/vit-base-patch16-224-in21k",
        type=str,
        help="Feature extractor model for the encoder",
    )
    parser.add_argument(
        "--decoder-model",
        default="distilbert/distilgpt2",
        type=str,
        help="Model for the decoder",
    )
    # One or more dataset names; valid names are the keys of DATASETS.
    # No default is set, so this is None when the flag is omitted.
    parser.add_argument(
        "--dataset",
        nargs="+",
        choices=list(DATASETS.keys()),
        help="Dataset to use for training",
    )
    return parser