in online_attacks/scripts/train_classifiers.py [0:0]
def create_argument_parser(argv=None) -> argparse.ArgumentParser:
    """Build the command-line parser for the classifier-training script.

    The valid ``--model_type`` values depend on ``--dataset``. The function
    therefore reads ``--dataset`` first with a throwaway pre-parser, then
    registers ``--model_type`` with the matching model enum.

    Args:
        argv: Argument list used only to peek at ``--dataset``. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, as before.
            Callers passing an explicit list should hand the same list to
            ``parse_args`` on the returned parser.

    Returns:
        An ``argparse.ArgumentParser`` whose ``--model_type`` option accepts
        ``MnistModel`` values when the dataset is MNIST and ``CifarModel``
        values when it is CIFAR. For any other dataset, ``--model_type`` is
        not registered.
    """
    dataset_kwargs = dict(
        default=DatasetType.MNIST,
        type=DatasetType,
        choices=DatasetType,
    )

    # Hack to be able to parse either MnistModel or CifarModel: peek at
    # --dataset first. The pre-parser uses add_help=False. Peeking with the
    # full parser would handle -h/--help right here and print a help message
    # missing every option registered below, then exit.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--dataset", **dataset_kwargs)
    known, _ = pre_parser.parse_known_args(argv)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", **dataset_kwargs)
    if known.dataset == DatasetType.MNIST:
        parser.add_argument(
            "--model_type",
            nargs="+",
            # nargs="+" yields a list when the flag is given, so the default
            # is a one-element list too; callers can always iterate it.
            default=[MnistModel.MODEL_A],
            type=MnistModel,
            choices=MnistModel,
        )
    elif known.dataset == DatasetType.CIFAR:
        parser.add_argument(
            "--model_type",
            nargs="+",
            # Same reasoning as above: keep the default's type consistent
            # with what nargs="+" produces.
            default=[CifarModel.VGG_16],
            type=CifarModel,
            choices=CifarModel,
        )
    # NOTE(review): any other DatasetType member leaves --model_type
    # unregistered, as in the original. Confirm whether that is intended.
    parser.add_argument("--train_on_test", action="store_true")
    parser.add_argument("--num_models", default=1, type=int)
    parser.add_argument("--slurm", type=str, default="")
    parser.add_argument("--robust", action="store_true")
    parser.add_argument("--model_attacker", default=None, type=str)
    return parser