def test()

in online_attacks/classifiers/trainer.py [0:0]


    def test(self, epoch: int) -> float:
        self.model.eval()
        if not isinstance(self.attacker, NoAttacker):
            self.attacker.predict.eval()
        test_loss = 0
        correct = 0
        adv_correct = 0
        for data, target in self.test_loader:
            with torch.no_grad():
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                # sum up batch loss
                test_loss += loss.item()
                # get the index of the max log-probability
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()

            if not isinstance(self.attacker, NoAttacker):
                data = self.attacker.perturb(data, target)
                with torch.no_grad():
                    output = self.model(data)
                    pred = output.max(1, keepdim=True)[1]
                    adv_correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_loader.dataset)

        acc = 100.0 * correct / len(self.test_loader.dataset)
        adv_acc = 100.0 * adv_correct / len(self.test_loader.dataset)

        if self.logger is None:
            log_output = "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
                test_loss, correct, len(self.test_loader.dataset), acc
            )
            if self.attacker is not None:
                log_output += ", Adv Accuracy: {}/{} ({:.0f}".format(
                    adv_correct, len(self.test_loader.dataset), adv_acc
                )
            print(log_output)
        else:
            results = dict(test_loss=test_loss, test_accuracy=acc)
            if self.attacker is not None:
                results["adv_accuracy"] = adv_acc
            self.logger.write(results, epoch)

        return acc