# online_attacks/classifiers/cifar/models/wide_resnet.py
def train(args, logger=None):
    """Train a Wide-ResNet classifier on CIFAR and checkpoint the best model.

    Args:
        args: namespace with at least ``seed``, ``normalize``, ``no_augment``,
            ``optimizer`` ("adam" | "sgd" | "sls"), ``lr``, ``output``,
            ``dataset``, plus the keyword arguments ``Wide_ResNet`` expects
            (the whole namespace is splatted into the model constructor).
        logger: optional object with ``write(dict, epoch)``; when given,
            per-epoch metrics go to it instead of / in addition to stdout.

    Side effects: writes the best checkpoint (by test accuracy) to
    ``<output>/<dataset>/wide_resnet/model_<seed>.pt``.
    """
    from utils.utils import create_loaders, seed_everything, CIFAR_NORMALIZATION
    import utils.config as cf
    import os
    import torch.backends.cudnn as cudnn
    import time

    seed_everything(args.seed)

    # Pick the input normalization transform; None means raw pixel values.
    normalize = None
    if args.normalize == "meanstd":
        from torchvision import transforms
        normalize = transforms.Normalize(cf.mean["cifar10"], cf.std["cifar10"])
    elif args.normalize == "default":
        normalize = CIFAR_NORMALIZATION

    # Hyper-parameter settings
    use_cuda = torch.cuda.is_available()
    best_acc = 0
    start_epoch, num_epochs = cf.start_epoch, cf.num_epochs

    # Data upload
    trainloader, testloader = create_loaders(
        args, augment=not args.no_augment, normalize=normalize
    )

    # Model
    print("\n[Phase 2] : Model setup")
    net = Wide_ResNet(**vars(args))
    file_name = os.path.join(
        args.output, "%s/%s/model_%i.pt" % (args.dataset, "wide_resnet", args.seed)
    )
    net.apply(conv_init)
    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()

    if args.optimizer == "adam":
        from torch.optim import Adam
        optimizer = Adam(net.parameters(), lr=args.lr)
    elif args.optimizer == "sgd":
        # No instance here on purpose: SGD is re-created every epoch inside
        # _train_epoch so the learning rate follows cf.learning_rate's schedule.
        from torch.optim import SGD
        optimizer = None
    elif args.optimizer == "sls":
        from utils.sls import Sls
        n_batches_per_epoch = len(trainloader)
        print(n_batches_per_epoch)
        optimizer = Sls(net.parameters(), n_batches_per_epoch=n_batches_per_epoch)
    else:
        # Fixed: previous message omitted "sls" even though it is handled above.
        raise ValueError("Only supports adam, sgd or sls for optimizer.")

    # Training (renamed from `train` to avoid shadowing the enclosing function)
    def _train_epoch(epoch, optimizer=None):
        """Run one pass over trainloader, printing running loss/accuracy."""
        net.train()
        net.training = True
        train_loss = 0
        correct = 0
        total = 0
        if args.optimizer == "sgd":
            optimizer = SGD(
                net.parameters(),
                lr=cf.learning_rate(args.lr, epoch),
                momentum=0.9,
                weight_decay=5e-4,
            )
        print(
            "\n=> Training Epoch #%d, LR=%.4f"
            % (epoch, cf.learning_rate(args.lr, epoch))
        )
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
            optimizer.zero_grad()
            inputs, targets = Variable(inputs), Variable(targets)
            outputs = net(inputs)  # Forward Propagation
            loss = criterion(outputs, targets)  # Loss
            if args.optimizer == "sls":
                # Sls performs a line search, so it needs a closure that
                # re-evaluates the loss; it handles backward() internally.
                def closure():
                    output = net(inputs)
                    return criterion(output, targets)

                optimizer.step(closure)
            else:
                loss.backward()
                optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()
            sys.stdout.write("\r")
            sys.stdout.write(
                "| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%"
                % (
                    epoch,
                    num_epochs,
                    batch_idx + 1,
                    len(trainloader),
                    loss.item(),
                    100.0 * correct / total,
                )
            )
            sys.stdout.flush()
        if logger is not None:
            logger.write(
                dict(train_accuracy=100.0 * correct / total, loss=loss.item()),
                epoch,
            )

    def _test_epoch(epoch, best_acc=0):
        """Evaluate on testloader; checkpoint and return the best accuracy."""
        net.eval()
        net.training = False
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += predicted.eq(targets.data).cpu().sum()

        # Save checkpoint when this is the best model so far.
        acc = 100.0 * correct / total
        if logger is None:
            print(
                "\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%"
                % (epoch, loss.item(), acc)
            )
        else:
            logger.write(dict(test_loss=loss.item(), test_accuracy=acc), epoch)
        if acc > best_acc:
            print("| Saving Best model...\t\t\tTop1 = %.2f%%" % (acc))
            # (Removed dead code: a `state` dict was built here but never
            # saved — only the state_dict below is actually written.)
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            # NOTE(review): with DataParallel this stores "module."-prefixed
            # keys; loaders of this checkpoint presumably expect that — confirm.
            torch.save(net.state_dict(), file_name)
            best_acc = acc
        return best_acc

    print("\n[Phase 3] : Training model")
    print("| Training Epochs = " + str(num_epochs))
    print("| Initial Learning Rate = " + str(args.lr))
    elapsed_time = 0
    for epoch in range(start_epoch, start_epoch + num_epochs):
        start_time = time.time()
        _train_epoch(epoch, optimizer)
        best_acc = _test_epoch(epoch, best_acc)
        epoch_time = time.time() - start_time
        elapsed_time += epoch_time
        print("| Elapsed time : %d:%02d:%02d" % (cf.get_hms(elapsed_time)))

    print("\n[Phase 4] : Testing model")
    print("* Test results : Acc@1 = %.2f%%" % (best_acc))