def train()

in online_attacks/classifiers/trainer.py [0:0]


    def train(self, epoch: int):
        """Run one training epoch over ``self.train_loader``.

        When ``self.attacker`` is not a ``NoAttacker``, every batch is replaced by
        ``self.attacker.perturb(data, target)`` before the parameter update.
        If ``self.optimizer`` is an ``Sls`` instance, the update goes through
        ``step(closure)``. Otherwise it is a plain ``backward()`` + ``step()``.

        A progress line is printed every 10 batches. After the epoch, if
        ``self.logger`` is set, it receives the epoch's training accuracy (percent)
        and the mean of the per-batch losses. If the loader yields no batches,
        nothing is logged.

        Args:
            epoch: Epoch index, used in the progress printout and passed to
                ``self.logger.write`` as the step.
        """
        self.model.train()
        train_loss = 0.0  # sum of per-batch loss values, averaged at the end
        num_seen_batches = 0
        correct = 0
        total = 0
        # Loop-invariant sizes used only for the progress printout.
        num_batches = len(self.train_loader)
        num_samples = len(self.train_loader.dataset)

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            if not isinstance(self.attacker, NoAttacker):
                # Craft adversarial inputs. The context manager presumably freezes
                # the attacked model's parameter grads and switches it to eval mode
                # while perturbing (advertorch semantics) — confirm against helper.
                with ctx_noparamgrad_and_eval(self.attacker.predict):
                    data = self.attacker.perturb(data, target)

            self.optimizer.zero_grad()

            output = self.model(data)
            loss = self.criterion(output, target)

            if isinstance(self.optimizer, Sls):
                # Sls takes a closure that re-evaluates the loss itself. In this
                # branch the `loss` computed above is never backpropagated; it
                # is only used for reporting below.
                def closure():
                    output = self.model(data)
                    # .mean() reduces to a scalar (a no-op if criterion already does).
                    loss = self.criterion(output, target).mean()
                    return loss

                self.optimizer.step(closure)
            else:
                loss.backward()
                self.optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            num_seen_batches += 1
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            if batch_idx % 10 == 0:
                print(
                    f"Train Epoch: {epoch:d} [{batch_idx * len(data):d}/{num_samples:d} "
                    f"({100. * batch_idx / num_batches:.0f}%)] \tLoss: {batch_loss:.6f} | "
                    f"Acc: {100. * correct / total:.3f}"
                )

        if self.logger is not None:
            if num_seen_batches == 0 or total == 0:
                # Empty loader: there is no loss or accuracy to report. The old
                # code crashed here with UnboundLocalError / ZeroDivisionError.
                return
            self.logger.write(
                dict(
                    train_accuracy=100.0 * correct / total,
                    # Epoch-level loss (mean over batches), consistent with the
                    # epoch-level accuracy instead of only the last batch's loss.
                    loss=train_loss / num_seen_batches,
                ),
                epoch,
            )