in imagenet.py [0:0]
def main():
torch.manual_seed(1)
np.random.seed(1)
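    # Seed the CPU-side RNGs for reproducibility; CUDA kernels (e.g. cudnn
    # convolutions) may still be nondeterministic across runs.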
# Load datasets
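    # Standard ImageNet channel statistics, as used by torchvision pretrained models.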
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(mean, std)])
    if 'augmix' in args.exp_id:
        print('using augmix data preprocessing...')
        # ToTensor/Normalize are deliberately omitted here: AugMixDataset
        # (constructed below with `preprocess`) applies them itself after
        # composing the augmentation chains.
        train_transform = transforms.Compose(
            [transforms.RandomResizedCrop(224),
             transforms.RandomHorizontalFlip()])
else:
print('using only standard data preprocessing...')
train_transform = transforms.Compose(
[transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
preprocess])
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
preprocess,
])
traindir = os.path.join(args.data_dir, 'train')
valdir = os.path.join(args.data_dir, 'validation')
train_dataset = datasets.ImageFolder(traindir, train_transform)
    assert os.path.isdir(args.corrupt_data_dir)  # corrupted data (ImageNet-C style), needed later by test_c()
if 'augmix' in args.exp_id:
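        # With no_jsd=False, AugMixDataset presumably yields (clean, aug1, aug2)
        # tuples so that a Jensen-Shannon consistency loss can be computed
        # during training, following the standard AugMix recipe.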
train_dataset = AugMixDataset(train_dataset, preprocess, all_ops=False, mixture_width=3,
mixture_depth=-1, aug_severity=1, no_jsd=False, image_size=224)
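    # pin_memory=True stages batches in page-locked host memory so that
    # host-to-GPU copies can overlap with computation.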
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=True)
test_dataset = datasets.ImageFolder(valdir, test_transform)
val_loader = torch.utils.data.DataLoader(
test_dataset,
        batch_size=1000,  # evaluation batch size is fixed, independent of args.batch_size
shuffle=False,
num_workers=args.workers,
pin_memory=True)
print('model: {}'.format(args.model))
    if args.model == 'resnet50':
        net = resnet50(args)
    elif args.model == 'resnet50_ibn_a':
        net = resnet50_ibn_a(args)
    elif args.model == 'resnet50_ibn_b':
        net = resnet50_ibn_b(args)
    else:
        raise ValueError('unknown model: {}'.format(args.model))
para_num = sum(p.numel() for p in net.parameters())
print('model param #: {}'.format(para_num))
if args.pretrained:
print('pretrained model: {}'.format(args.pretrained))
        state_dict = torch.load(args.pretrained, map_location='cpu')
        # strict=False tolerates key mismatches, e.g. extra CrossNorm/SelfNorm
        # parameters that a vanilla pretrained checkpoint does not contain.
        net.load_state_dict(state_dict, strict=False)
print('optimizer momentum: {}'.format(args.momentum))
print('optimizer weight_decay: {}'.format(args.weight_decay))
optimizer = torch.optim.SGD(
net.parameters(),
args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
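    # Building the optimizer before net.cuda() is safe here: Module.cuda()
    # moves parameters in place, so the optimizer keeps references to the
    # live tensors.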
# Distribute model across all visible GPUs
net = torch.nn.DataParallel(net).cuda()
cudnn.benchmark = True
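    # Notes: DataParallel prefixes parameter names with 'module.', so the
    # checkpoints saved below load back cleanly only into a wrapped model
    # (as the resume branch does); cudnn.benchmark auto-tunes convolution
    # algorithms, which pays off since the input size is fixed at 224x224.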
    start_epoch = 0
    best_acc = 0.  # defaults; overwritten below when a checkpoint is resumed
if args.resume:
        exp_dir = os.path.dirname(args.resume)
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
start_epoch = checkpoint['epoch']
best_acc = checkpoint['best_acc']
net.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
# print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
# print('exp_dir: {}'.format(exp_dir))
        else:
            print("=> no checkpoint found at '{}'; training from scratch".format(args.resume))
else:
exp_dir = get_log_dir_path(args.exp_dir, args.exp_id)
if not os.path.exists(exp_dir):
os.makedirs(exp_dir)
if args.evaluate:
test_loss, test_acc1 = test(net, val_loader)
print('Clean\n\tTest Loss {:.3f} | Test Acc1 {:.3f}'.format(
test_loss, 100 * test_acc1))
corruption_accs = test_c(net, test_transform)
for c in CORRUPTIONS:
print('\t'.join(map(str, [c] + corruption_accs[c])))
mce, ce_dict = compute_mce(corruption_accs)
print_ces(ce_dict)
        print('mCE (normalized by AlexNet): {:.2f}'.format(mce))
return
print('exp_dir: {}'.format(exp_dir))
log_file = os.path.join(exp_dir, 'log.txt')
    names = ['epoch', 'lr', 'Train Err1', 'Test Err1', 'Best Test Err1']
with open(log_file, 'a') as f:
f.write('batch size: {}\n'.format(args.batch_size))
f.write('lr: {}\n'.format(args.lr))
f.write('momentum: {}\n'.format(args.momentum))
f.write('weight_decay: {}\n'.format(args.weight_decay))
for per_name in names:
f.write(per_name + '\t')
f.write('\n')
    print('start_epoch: {}'.format(start_epoch))
    print('total epochs: {}'.format(args.epochs))
    print('best_acc: {}'.format(best_acc))
if args.cn_prob:
print('cn_prob: {}'.format(args.cn_prob))
if args.consist_wt:
print('consist_wt: {}'.format(args.consist_wt))
for epoch in range(start_epoch, args.epochs):
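        # adjust_learning_rate is assumed to implement the LR decay schedule
        # defined elsewhere in this file.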
adjust_learning_rate(optimizer, epoch)
lr = optimizer.param_groups[0]['lr']
print('lr: {}'.format(lr))
        if 'augmix' in args.exp_id:  # AugMix combined with CrossNorm in image space
            assert args.cn_prob > 0
            train_err1 = train_cn_image_augmix(net, train_loader, optimizer)
        elif 'consist' in args.exp_id:  # CrossNorm in image space with a consistency loss
            assert args.cn_prob > 0
            train_err1 = train_cn_image_consist(net, train_loader, optimizer)
        elif 'cn' in args.exp_id:  # CrossNorm in image space; 'cn' is not in cnsn_type
            assert args.cn_prob > 0
            train_err1 = train_cn_image(net, train_loader, optimizer)
        else:
            train_err1 = train(net, train_loader, optimizer)
test_loss, test_acc = test(net, val_loader)
is_best = test_acc > best_acc
best_acc = max(test_acc, best_acc)
save_checkpoint(net, {
'epoch': epoch + 1,
'state_dict': net.state_dict(),
'best_acc': best_acc,
'optimizer': optimizer.state_dict(),
}, is_best, exp_dir, epoch=epoch)
values = [train_err1, 100 - 100. * test_acc, 100 - 100. * best_acc]
with open(log_file, 'a') as f:
f.write('{:d}\t'.format(epoch))
f.write('{:g}\t'.format(lr))
for per_value in values:
f.write('{:2.2f}\t'.format(per_value))
f.write('\n')
print('exp_dir: {}'.format(exp_dir))
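    # Final robustness evaluation on the corrupted validation sets under
    # args.corrupt_data_dir (checked at startup), across all CORRUPTIONS.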
corruption_accs = test_c(net, test_transform)
for c in CORRUPTIONS:
print('\t'.join(map(str, [c] + corruption_accs[c])))
mce, ce_dict = compute_mce(corruption_accs)
print_ces(ce_dict)
print('mCE (normalized by AlexNet): {:.2f}'.format(mce))
with open(log_file, 'a') as f:
f.write('individual corruption errors: \n')
for per in CORRUPTIONS:
f.write('{0}: {ce:.2f}\n'.format(per, ce=ce_dict[per]))
        f.write('mCE: {:.2f}\n'.format(mce))