# train() — extracted from online_attacks/classifiers/cifar/models/wide_resnet.py

def train(args, logger=None):
    """Train a Wide-ResNet classifier on CIFAR and checkpoint the best model.

    Args:
        args: namespace providing at least ``seed``, ``normalize``,
            ``no_augment``, ``output``, ``dataset``, ``lr`` and ``optimizer``
            (one of ``"adam"``, ``"sgd"``, ``"sls"``); the remaining entries
            are forwarded to ``Wide_ResNet`` via ``vars(args)``.
        logger: optional object with a ``write(metrics_dict, epoch)`` method;
            when given, per-batch train metrics and per-epoch test metrics are
            written to it instead of (for test) being printed.

    Side effects: saves the best checkpoint (by test accuracy) to
    ``<output>/<dataset>/wide_resnet/model_<seed>.pt`` and prints progress.
    """
    from utils.utils import create_loaders, seed_everything, CIFAR_NORMALIZATION
    import utils.config as cf
    import os
    import torch.backends.cudnn as cudnn
    import time

    seed_everything(args.seed)

    # Choose input normalization: dataset mean/std, project default, or none.
    normalize = None
    if args.normalize == "meanstd":
        from torchvision import transforms

        normalize = transforms.Normalize(cf.mean["cifar10"], cf.std["cifar10"])
    elif args.normalize == "default":
        normalize = CIFAR_NORMALIZATION

    # Hyper-parameter settings
    use_cuda = torch.cuda.is_available()
    best_acc = 0
    start_epoch, num_epochs = cf.start_epoch, cf.num_epochs

    # Data loading
    trainloader, testloader = create_loaders(
        args, augment=not args.no_augment, normalize=normalize
    )

    # Model
    print("\n[Phase 2] : Model setup")
    net = Wide_ResNet(**vars(args))
    file_name = os.path.join(
        args.output, "%s/%s/model_%i.pt" % (args.dataset, "wide_resnet", args.seed)
    )
    net.apply(conv_init)

    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()

    if args.optimizer == "adam":
        from torch.optim import Adam

        optimizer = Adam(net.parameters(), lr=args.lr)
    elif args.optimizer == "sgd":
        from torch.optim import SGD

        # SGD is (re)built each epoch inside train_epoch so the learning-rate
        # schedule cf.learning_rate(lr, epoch) can be applied; leave it None here.
        optimizer = None
    elif args.optimizer == "sls":
        from utils.sls import Sls

        n_batches_per_epoch = len(trainloader)
        print(n_batches_per_epoch)
        optimizer = Sls(net.parameters(), n_batches_per_epoch=n_batches_per_epoch)
    else:
        # Fixed: the message previously omitted the supported "sls" option.
        raise ValueError("Only supports adam, sgd, or sls for optimizer.")

    # Renamed from "train" to avoid shadowing the enclosing function.
    def train_epoch(epoch, optimizer=None):
        """Run one training epoch; logs per-batch loss/accuracy."""
        net.train()  # sets net.training = True
        train_loss = 0
        correct = 0
        total = 0
        if args.optimizer == "sgd":
            # Fresh SGD each epoch implements the step-wise LR schedule.
            optimizer = SGD(
                net.parameters(),
                lr=cf.learning_rate(args.lr, epoch),
                momentum=0.9,
                weight_decay=5e-4,
            )

        print(
            "\n=> Training Epoch #%d, LR=%.4f"
            % (epoch, cf.learning_rate(args.lr, epoch))
        )
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
            optimizer.zero_grad()
            inputs, targets = Variable(inputs), Variable(targets)
            outputs = net(inputs)  # Forward Propagation
            loss = criterion(outputs, targets)  # Loss

            if args.optimizer == "sls":
                # Sls performs its own line search; it re-evaluates the loss
                # via the closure instead of using a precomputed gradient.
                def closure():
                    output = net(inputs)
                    loss = criterion(output, targets)
                    return loss

                optimizer.step(closure)
            else:
                loss.backward()
                optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()

            sys.stdout.write("\r")
            sys.stdout.write(
                "| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%"
                % (
                    epoch,
                    num_epochs,
                    batch_idx + 1,
                    len(trainloader),
                    loss.item(),
                    100.0 * correct / total,
                )
            )
            sys.stdout.flush()

            if logger is not None:
                logger.write(
                    dict(train_accuracy=100.0 * correct / total, loss=loss.item()),
                    epoch,
                )

    def test(epoch, best_acc=0):
        """Evaluate on the test set; checkpoint and return the best accuracy."""
        net.eval()  # sets net.training = False
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                outputs = net(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()

            # Save checkpoint when best model
            acc = 100.0 * correct / total
            if logger is None:
                print(
                    "\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%"
                    % (epoch, loss.item(), acc)
                )
            else:
                logger.write(dict(test_loss=loss.item(), test_accuracy=acc), epoch)

            if acc > best_acc:
                print("| Saving Best model...\t\t\tTop1 = %.2f%%" % (acc))
                # Removed a dead `state` dict that was built but never saved;
                # the checkpoint format (net.state_dict()) is unchanged.
                # NOTE(review): under DataParallel this saves "module."-prefixed
                # keys — confirm loaders expect that before changing it.
                os.makedirs(os.path.dirname(file_name), exist_ok=True)
                torch.save(net.state_dict(), file_name)
                best_acc = acc
        return best_acc

    print("\n[Phase 3] : Training model")
    print("| Training Epochs = " + str(num_epochs))
    print("| Initial Learning Rate = " + str(args.lr))

    elapsed_time = 0
    for epoch in range(start_epoch, start_epoch + num_epochs):
        start_time = time.time()

        train_epoch(epoch, optimizer)
        best_acc = test(epoch, best_acc)

        epoch_time = time.time() - start_time
        elapsed_time += epoch_time
        print("| Elapsed time : %d:%02d:%02d" % (cf.get_hms(elapsed_time)))

    print("\n[Phase 4] : Testing model")
    print("* Test results : Acc@1 = %.2f%%" % (best_acc))