def run()

in online_attacks/scripts/train_classifiers.py [0:0]


    def run(cls, args):
        """Train an MNIST or CIFAR classifier as configured by ``args``.

        Builds the dataset-specific structured ``TrainingParams`` config via
        ``OmegaConf.structured``, copies the relevant command-line options
        into it, derives the run name, and calls the dataset module's
        ``train`` function. Uses CUDA when ``torch.cuda.is_available()``
        returns True, otherwise CPU.

        The run name has the form ``[<ATTACKER>_](test_|train_)<args.name>``;
        the attacker prefix is present only when ``args.robust`` is truthy.

        Args:
            args: Parsed options. Attributes read: ``dataset``, ``model_type``,
                ``train_on_test``, ``robust``, ``name`` and ``model_attacker``
                (the latter is applied only when it is not ``None``).

        Raises:
            ValueError: If ``args.dataset`` is neither ``DatasetType.MNIST``
                nor ``DatasetType.CIFAR``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def _apply_common_params(params):
            """Copy the dataset-independent options from ``args`` into ``params``."""
            params.train_on_test = args.train_on_test
            prefix = ""
            if args.robust:
                # Robust (adversarial) training: switch the config's attacker
                # to PGD and record it in the run name.
                params.attacker = Attacker.PGD_ATTACK
                prefix = "%s_" % params.attacker.name
            split = "test_" if params.train_on_test else "train_"
            params.name = prefix + split + str(args.name)
            if args.model_attacker is not None:
                params.model_attacker = args.model_attacker

        if args.dataset == DatasetType.MNIST:
            # Imported lazily so only the selected dataset's module is loaded.
            import online_attacks.classifiers.mnist as mnist

            params = OmegaConf.structured(mnist.TrainingParams)
            params.model_type = args.model_type
            _apply_common_params(params)
            mnist.train(params, device=device)

        elif args.dataset == DatasetType.CIFAR:
            import online_attacks.classifiers.cifar as cifar

            params = OmegaConf.structured(cifar.TrainingParams)
            params.model_type = args.model_type
            # Per-architecture batch-size overrides; presumably the larger
            # models need smaller batches to fit in memory — TODO confirm.
            if params.model_type in [CifarModel.GOOGLENET, CifarModel.WIDE_RESNET]:
                params.dataset_params.batch_size = 64
                params.dataset_params.test_batch_size = 256
            elif params.model_type == CifarModel.DENSE_121:
                params.dataset_params.batch_size = 128
            _apply_common_params(params)
            cifar.train(params, device=device)

        else:
            raise ValueError("Unsupported dataset: %r" % (args.dataset,))