in online_attacks/classifiers/trainer.py [0:0]
def test(self, epoch: int) -> float:
    """Evaluate the model on ``self.test_loader`` and report the results.

    The model is put in eval mode and every batch of the test loader is
    scored on its clean inputs. When an actual attacker is configured
    (anything other than ``None`` or a ``NoAttacker`` instance), each batch
    is additionally perturbed with ``self.attacker.perturb`` and scored
    again to obtain an adversarial accuracy.

    Results are printed when ``self.logger`` is ``None``; otherwise they
    are passed to ``self.logger.write`` together with ``epoch``. The
    adversarial accuracy is only reported when an attack was actually run.

    Args:
        epoch: Value forwarded to ``self.logger.write`` alongside the
            results.

    Returns:
        Clean accuracy on the test set, as a percentage.
    """
    self.model.eval()

    # One flag drives setup, evaluation and reporting, so the adversarial
    # accuracy is never reported for a run in which no attack took place.
    has_attacker = self.attacker is not None and not isinstance(
        self.attacker, NoAttacker
    )
    if has_attacker:
        self.attacker.predict.eval()

    test_loss = 0
    correct = 0
    adv_correct = 0
    for data, target in self.test_loader:
        with torch.no_grad():
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = self.criterion(output, target)
            # sum up batch loss
            test_loss += loss.item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        if has_attacker:
            # The perturbation is computed outside ``torch.no_grad()``
            # (presumably because the attack needs gradients -- confirm
            # against the attacker implementation).
            data = self.attacker.perturb(data, target)
            with torch.no_grad():
                output = self.model(data)
                pred = output.max(1, keepdim=True)[1]
                adv_correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(self.test_loader.dataset)
    # NOTE(review): the per-batch loss values are summed and divided by the
    # dataset size; this is a per-sample average only if ``self.criterion``
    # uses a sum reduction -- confirm.
    test_loss /= num_samples
    acc = 100.0 * correct / num_samples
    adv_acc = 100.0 * adv_correct / num_samples

    if self.logger is None:
        log_output = "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss, correct, num_samples, acc
        )
        if has_attacker:
            log_output += ", Adv Accuracy: {}/{} ({:.0f}%)".format(
                adv_correct, num_samples, adv_acc
            )
        print(log_output)
    else:
        results = dict(test_loss=test_loss, test_accuracy=acc)
        if has_attacker:
            results["adv_accuracy"] = adv_acc
        self.logger.write(results, epoch)

    return acc