in distilvit/train.py [0:0]
def get_arg_parser(root_dir=None) -> argparse.ArgumentParser:
    """Build the command-line parser for training a Vision Encoder Decoder model.

    Args:
        root_dir: Base directory used to derive the default ``--save-dir``
            (``root_dir`` itself), ``--cache-dir`` (``<root_dir>/cache``) and
            ``--checkpoints-dir`` (``<root_dir>/checkpoints``). When ``None``,
            the parent directory of this module is used.

    Returns:
        A configured ``argparse.ArgumentParser``. Callers are expected to
        invoke ``parse_args()`` on it.
    """
    if root_dir is None:
        root_dir = os.path.join(os.path.dirname(__file__), "..")

    parser = argparse.ArgumentParser(
        description="Train a Vision Encoder Decoder Model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model-id",
        default=MODEL_ID,
        type=str,
        help="Model ID",
    )
    parser.add_argument(
        "--sample",
        default=None,
        type=int,
        help="Sample data",
    )
    parser.add_argument(
        "--tag",
        type=str,
        help="HF tag",
        default=None,
    )
    parser.add_argument(
        "--save-dir",
        default=root_dir,
        type=str,
        help="Save dir",
    )
    parser.add_argument(
        "--cache-dir",
        default=os.path.join(root_dir, "cache"),
        type=str,
        help="Cache dir",
    )
    # store_true already defaults to False; no explicit default needed.
    parser.add_argument(
        "--prune-cache",
        action="store_true",
        help="Empty cache dir",
    )
    parser.add_argument(
        "--checkpoints-dir",
        default=os.path.join(root_dir, "checkpoints"),
        type=str,
        help="Checkpoints dir",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Debug mode",
    )
    parser.add_argument(
        "--num-train-epochs", type=int, default=3, help="Number of epochs"
    )
    parser.add_argument("--eval-steps", type=int, default=100, help="Evaluation steps")
    parser.add_argument("--save-steps", type=int, default=100, help="Save steps")
    # NOTE(review): the encoder default is the fine-tuned
    # "google/vit-base-patch16-224" while --feature-extractor-model defaults to
    # "google/vit-base-patch16-224-in21k"; confirm the mismatch is intentional.
    parser.add_argument(
        "--encoder-model",
        default="google/vit-base-patch16-224",
        type=str,
        help="Base model for the encoder",
    )
    parser.add_argument(
        "--base-model",
        default=None,
        type=str,
        help="Base model to train again from",
    )
    # get_device() is evaluated once, when the parser is built, to pick the
    # default device.
    parser.add_argument(
        "--device",
        default=get_device(),
        type=str,
        choices=["cpu", "cuda", "mps"],
        help="Device to train on",
    )
    parser.add_argument(
        "--base-model-revision",
        default=None,
        type=str,
        help="Base model revision",
    )
    parser.add_argument("--push-to-hub", action="store_true", help="Push to hub")
    parser.add_argument(
        "--feature-extractor-model",
        default="google/vit-base-patch16-224-in21k",
        type=str,
        help="Feature extractor model for the encoder",
    )
    parser.add_argument(
        "--decoder-model",
        default="distilbert/distilgpt2",
        type=str,
        help="Model for the decoder",
    )
    # No default: if --dataset is omitted, args.dataset is None.
    parser.add_argument(
        "--dataset",
        nargs="+",
        choices=list(DATASETS.keys()),
        help="Dataset to use for training",
    )
    return parser