in online_attacks/scripts/train_classifiers.py [0:0]
def run(cls, args):
    """Train a classifier for the dataset selected by ``args.dataset``.

    Builds an OmegaConf structured ``TrainingParams`` config for the chosen
    dataset, fills it from ``args`` and passes it to that dataset's
    ``train`` function. Training runs on CUDA when available, otherwise CPU.

    Args:
        args: Parsed options. Reads ``dataset``, ``model_type``,
            ``train_on_test``, ``robust``, ``name`` and ``model_attacker``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``DatasetType.MNIST`` nor
            ``DatasetType.CIFAR``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _apply_shared_params(params):
        # Settings common to every dataset, kept in one place so the MNIST
        # and CIFAR branches cannot drift apart.
        params.train_on_test = args.train_on_test
        params.name = ""
        if args.robust:
            # Robust mode: select the PGD attacker (presumably consumed by the
            # train function for adversarial training -- confirm) and prefix
            # the run name with the attacker's ``.name``.
            params.attacker = Attacker.PGD_ATTACK
            params.name = "%s_" % params.attacker.name
        # Run name: "[<ATTACKER>_]" + "test_"/"train_" + str(args.name).
        params.name += "test_" if params.train_on_test else "train_"
        params.name += str(args.name)
        if args.model_attacker is not None:
            params.model_attacker = args.model_attacker

    if args.dataset == DatasetType.MNIST:
        # Imported lazily so only the selected dataset's module is loaded.
        import online_attacks.classifiers.mnist as mnist

        params = OmegaConf.structured(mnist.TrainingParams)
        params.model_type = args.model_type
        _apply_shared_params(params)
        mnist.train(params, device=device)
    elif args.dataset == DatasetType.CIFAR:
        import online_attacks.classifiers.cifar as cifar

        params = OmegaConf.structured(cifar.TrainingParams)
        params.model_type = args.model_type
        # These architectures get smaller batch sizes than the config default
        # (presumably to fit in GPU memory -- TODO confirm).
        if params.model_type in [CifarModel.GOOGLENET, CifarModel.WIDE_RESNET]:
            params.dataset_params.batch_size = 64
            params.dataset_params.test_batch_size = 256
        elif params.model_type == CifarModel.DENSE_121:
            params.dataset_params.batch_size = 128
        _apply_shared_params(params)
        cifar.train(params, device=device)
    else:
        raise ValueError("Unsupported dataset: %r" % (args.dataset,))