def add_model_specific_args()

in src/sagemaker_defect_detection/detector.py [0:0]


    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """Build a parser that extends ``parent_parser`` with detector options.

        Args:
            parent_parser: An ``argparse.ArgumentParser`` whose arguments are
                inherited by the returned parser through ``parents=``.

        Returns:
            A new ``ArgumentParser`` (created with ``add_help=False``). It holds
            the parent's arguments plus the following:

            - the RPN/ROI train and fine-tune flags;
            - the data path;
            - backbone and class count;
            - optimizer hyper-parameters (batch size, learning rate, momentum,
              weight decay);
            - the seed and checkpoint paths.

        ``--data-path`` defaults to the ``SM_CHANNEL_TRAINING`` environment
        variable. When that variable is not set (for example outside a
        SageMaker training job), the option is required on the command line
        instead of building the parser failing with ``KeyError``.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        aa = parser.add_argument
        # Boolean stage flags; how they are combined is decided by the caller.
        aa("--train-rpn", action="store_true")
        aa("--train-roi", action="store_true")
        aa("--finetune-rpn", action="store_true")
        aa("--finetune-roi", action="store_true")
        # Use .get() so constructing the parser never raises when the env var is
        # missing; in that case the user must supply --data-path explicitly.
        default_data_path = os.environ.get("SM_CHANNEL_TRAINING")
        aa(
            "--data-path",
            metavar="DIR",
            type=str,
            default=default_data_path,
            required=default_data_path is None,
        )
        aa("--backbone", default="resnet34", help="backbone model either resnet34 (default) or resnet50")
        aa("--num-classes", default=7, type=int, metavar="N", help="number of classes including the background")
        aa(
            "-b",
            "--batch-size",
            default=16,
            type=int,
            metavar="N",
            help="mini-batch size (default: 16), this is the total "
            "batch size of all GPUs on the current node when "
            "using Data Parallel or Distributed Data Parallel",
        )
        aa(
            "--lr",
            "--learning-rate",
            default=1e-3,
            type=float,
            metavar="LR",
            help="initial learning rate",
            dest="learning_rate",
        )
        aa("--momentum", default=0.9, type=float, metavar="M", help="momentum")
        aa(
            "--wd",
            "--weight-decay",
            default=1e-4,
            type=float,
            metavar="W",
            help="weight decay (default: 1e-4)",
            dest="weight_decay",
        )
        aa("--seed", type=int, default=123, help="seed for initializing training")
        # Optional checkpoint paths; each defaults to None when not given.
        aa("--pretrained-mfn-ckpt", type=str)
        aa("--pretrained-rpn-ckpt", type=str)
        aa("--pretrained-roi-ckpt", type=str)
        aa("--finetuned-rpn-ckpt", type=str)
        aa("--finetuned-roi-ckpt", type=str)
        aa("--resume-from-checkpoint", type=str)
        # Falls back to None when the SageMaker checkpoint channel is not configured.
        aa(
            "--resume-sagemaker-from-checkpoint",
            type=str,
            default=os.environ.get("SM_CHANNEL_PRETRAINED_CHECKPOINT"),
        )
        return parser