in online_attacks/classifiers/trainer.py [0:0]
def train(self, epoch: int):
    """Run one training epoch over ``self.train_loader``.

    When ``self.attacker`` is not a ``NoAttacker``, each batch is replaced by
    ``self.attacker.perturb(data, target)``. That call runs inside
    ``ctx_noparamgrad_and_eval(self.attacker.predict)``, and the model is
    then trained on the perturbed inputs.

    Two optimizer paths are supported:

    * ``Sls``: ``self.optimizer.step`` receives a closure that recomputes the
      mean loss. This method does not call ``backward`` itself on that path;
      gradient computation is left to the optimizer.
    * any other optimizer: a standard ``loss.backward()`` followed by
      ``self.optimizer.step()``.

    Every 10 batches, the current batch loss and the running accuracy are
    printed. At the end of the epoch, if ``self.logger`` is set and at least
    one sample was seen, ``self.logger.write`` is called with:

    * the running accuracy, in percent;
    * the mean of the per-batch losses;
    * ``epoch`` as the step.

    Args:
        epoch: Epoch index, used in the progress output and as the logger step.
    """
    import torch  # needed for torch.no_grad on the Sls path

    self.model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (data, target) in enumerate(self.train_loader):
        data, target = data.to(self.device), target.to(self.device)
        if not isinstance(self.attacker, NoAttacker):
            with ctx_noparamgrad_and_eval(self.attacker.predict):
                data = self.attacker.perturb(data, target)
        self.optimizer.zero_grad()
        if isinstance(self.optimizer, Sls):
            # This forward pass only feeds the metrics below and is never
            # backpropagated. Skip building an autograd graph so its
            # activations are not held in memory during the optimizer step.
            with torch.no_grad():
                output = self.model(data)
                loss = self.criterion(output, target).mean()

            def closure():
                return self.criterion(self.model(data), target).mean()

            self.optimizer.step(closure)
        else:
            output = self.model(data)
            # .mean() matches the Sls closure. It is a no-op for an
            # already-reduced criterion and makes an unreduced (per-sample)
            # criterion usable with backward().
            loss = self.criterion(output, target).mean()
            loss.backward()
            self.optimizer.step()
        batch_loss = loss.item()
        train_loss += batch_loss
        num_batches += 1
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        if batch_idx % 10 == 0:
            print(
                f"Train Epoch: {epoch:d} "
                f"[{batch_idx * len(data):d}/{len(self.train_loader.dataset):d} "
                f"({100.0 * batch_idx / len(self.train_loader):.0f}%)]\t"
                f"Loss: {batch_loss:.6f} | Acc: {100.0 * correct / total:.3f}%"
            )
    # An empty loader leaves nothing to report (and would divide by zero).
    if self.logger is not None and total > 0:
        self.logger.write(
            dict(
                train_accuracy=100.0 * correct / total,
                # Epoch-level loss: mean over batches, not the last batch's value.
                loss=train_loss / num_batches,
            ),
            epoch,
        )