in cifar.py
def main():
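    # fix RNG seeds so runs are repeatable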
torch.manual_seed(1)
np.random.seed(1)
# datasets
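    # In the AugMix branch, ToTensor/Normalize are deliberately left out of
    # train_transform: `preprocess` below is applied inside AugMixDataset to
    # each augmentation chain instead (see the wrapping further down).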
if 'augmix' in args.exp_id:
train_transform = transforms.Compose(
[transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4)])
else:
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5] * 3, [0.5] * 3),
])
preprocess = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5] * 3, [0.5] * 3)])
test_transform = preprocess
if args.dataset.lower() == 'cifar-10':
print('using cifar-10 data ...')
train_data = datasets.CIFAR10(
root=args.data_dir, train=True, transform=train_transform, download=True)
test_data = datasets.CIFAR10(
root=args.data_dir, train=False, transform=test_transform, download=True)
base_c_path = args.corrupt_data_dir
num_classes = 10
elif args.dataset.lower() == 'cifar-100':
print('using cifar-100 data ...')
train_data = datasets.CIFAR100(
root=args.data_dir, train=True, transform=train_transform, download=True)
test_data = datasets.CIFAR100(
root=args.data_dir, train=False, transform=test_transform, download=True)
base_c_path = args.corrupt_data_dir
num_classes = 100
else:
raise Exception('unknown dataset: {}'.format(args.dataset))
assert os.path.isdir(base_c_path)
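    # base_c_path is expected to contain the CIFAR-10-C / CIFAR-100-C
    # corruption files consumed by test_c() at the end of training.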
if 'augmix' in args.exp_id:
train_data = AugMixDataset(train_data, preprocess, all_ops=False, mixture_width=3,
mixture_depth=-1, aug_severity=3, no_jsd=False, image_size=32)
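        # With no_jsd=False, AugMixDataset is assumed to follow the AugMix
        # reference implementation and yield a (clean, aug1, aug2) tuple per
        # sample for the Jensen-Shannon consistency loss used by the
        # augmix training routine below.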
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=True)
test_loader = torch.utils.data.DataLoader(
test_data,
batch_size=1000,
shuffle=False,
num_workers=args.workers,
pin_memory=True)
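    # evaluation uses a fixed batch of 1000 with shuffling disabled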
# model
print('model: {}'.format(args.model))
if args.model == 'wideresnet':
net = WideResNet(40, num_classes=num_classes, widen_factor=2, drop_rate=0,
active_num=args.active_num, pos=args.pos,
beta=args.beta, crop=args.crop, cnsn_type=args.cnsn_type)
elif args.model == 'allconv':
net = AllConvNet(num_classes, active_num=args.active_num, pos=args.pos,
beta=args.beta, crop=args.crop,
cnsn_type=args.cnsn_type)
elif args.model == 'resnext':
net = resnext29(num_classes=num_classes, config=args)
elif args.model == 'densenet':
net = densenet(num_classes=num_classes, config=args)
else:
        raise Exception('unknown model: {}'.format(args.model))
para_num = sum(p.numel() for p in net.parameters())
print('model param #: {}'.format(para_num))
net = torch.nn.DataParallel(net).cuda()
cudnn.benchmark = True
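    # cudnn.benchmark lets cuDNN auto-tune convolution kernels; safe here
    # because the input size is fixed at 32x32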
# optimizer
optimizer = optim.SGD(net.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov=True)
for group in optimizer.param_groups:
print('lr: {}, weight_decay: {}, momentum: {}, nesterov: {}'
.format(group['lr'], group['weight_decay'], group['momentum'], group['nesterov']))
# lr scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda=lambda step: get_lr( # pylint: disable=g-long-lambda
step,
args.epochs * len(train_loader),
1, # lr_lambda computes multiplicative factor
1e-6 / args.lr))
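    # get_lr is assumed to implement the cosine schedule from the AugMix
    # reference code, i.e. a multiplicative factor of roughly
    #   lr_min/lr_max + (1 - lr_min/lr_max) * 0.5 * (1 + cos(pi * step / total_steps))
    # decaying args.lr towards 1e-6. Since total_steps is
    # args.epochs * len(train_loader), the train_* loops are expected to call
    # scheduler.step() once per batch, not once per epoch.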
if args.resume:
# print_logits(net, train_loader, 100)
print('resume checkpoint: {}'.format(args.resume))
exp_dir_idx = args.resume.rindex('/')
exp_dir = args.resume[:exp_dir_idx]
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
start_epoch = checkpoint['epoch']
best_acc = checkpoint['best_acc']
net.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
# print('exp_dir: {}'.format(exp_dir))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            start_epoch = 0  # fall back to training from scratch
            best_acc = 0.
else:
start_epoch = 0
best_acc = 0.
exp_dir = get_log_dir_path(args.exp_dir, args.exp_id)
if not os.path.exists(exp_dir):
os.makedirs(exp_dir)
if args.evaluate:
# Evaluate clean accuracy first because test_c mutates underlying data
test_loss, test_acc = test(net, test_loader)
print('Clean\n\tTest Loss {:.3f} | Test Error {:.2f}'.format(
test_loss, 100 - 100. * test_acc))
test_c_acc = test_c(net, test_data, base_c_path)
print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
return
print('exp_dir: {}'.format(exp_dir))
log_file = os.path.join(exp_dir, 'log.txt')
    names = ['epoch', 'lr', 'Train Loss', 'Test Err1', 'Best Test Err1']
with open(log_file, 'a') as f:
f.write('dataset: {}\n'.format(args.dataset))
f.write('batch size: {}\n'.format(args.batch_size))
f.write('lr: {}\n'.format(args.lr))
f.write('momentum: {}\n'.format(args.momentum))
f.write('weight_decay: {}\n'.format(args.weight_decay))
for per_name in names:
f.write(per_name + '\t')
f.write('\n')
# print('=> Training the base model')
print('start_epoch {}'.format(start_epoch))
print('total epochs: {}'.format(args.epochs))
print('best_acc: {}'.format(best_acc))
# print('best_err5: {}'.format(best_err5))
if args.cn_prob:
print('cn_prob: {}'.format(args.cn_prob))
if args.consist_wt:
print('consist_wt: {}'.format(args.consist_wt))
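    # per-epoch loop: dispatch to one of four training routines based on the
    # experiment id and the 'cn' flag in cnsn_type, then evaluate on the
    # clean test set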
for epoch in range(start_epoch, args.epochs):
lr = optimizer.param_groups[0]['lr']
if 'augmix' in args.exp_id and 'cn' in args.cnsn_type:
assert args.cn_prob > 0 and args.consist_wt > 0
train_loss_ema = train_cn_augmix(net, train_loader, optimizer, scheduler)
elif 'consist' in args.exp_id and 'cn' in args.cnsn_type:
assert args.cn_prob > 0 and args.consist_wt > 0
train_loss_ema = train_cn_consistency(net, train_loader, optimizer, scheduler)
elif 'cn' in args.cnsn_type:
assert args.cn_prob > 0
train_loss_ema = train_cn(net, train_loader, optimizer, scheduler)
else:
train_loss_ema = train(net, train_loader, optimizer, scheduler)
test_loss, test_acc = test(net, test_loader)
# test_c_acc = test_c(net, test_data, base_c_path)
is_best = test_acc > best_acc
best_acc = max(test_acc, best_acc)
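        # checkpoint every epoch; save_checkpoint is assumed to also copy the
        # weights aside when is_best is set (best clean test accuracy so far)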
save_checkpoint(net, {
'epoch': epoch + 1,
'state_dict': net.state_dict(),
'best_acc': best_acc,
'optimizer': optimizer.state_dict(),
}, is_best, exp_dir, epoch=None)
values = [train_loss_ema, 100 - 100. * test_acc, 100 - 100. * best_acc]
with open(log_file, 'a') as f:
f.write('{:d}\t'.format(epoch))
f.write('{:g}\t'.format(lr))
for per_value in values:
f.write('{:2.2f}\t'.format(per_value))
f.write('\n')
print('exp_dir: {}'.format(exp_dir))
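    # final robustness evaluation; test_c is assumed to average accuracy over
    # the CIFAR-C corruption types and severities, so the value printed below
    # is the mean corruption error in percent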
test_c_acc = test_c(net, test_data, base_c_path)
print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
with open(log_file, 'a') as f:
f.write('{:2.2f}\t'.format(100 - 100. * test_c_acc))
f.write('\n')