in LaNAS/LaNet/CIFAR10/train.py [0:0]
def main():
    """Train a LaNet-discovered architecture on CIFAR-10, optionally resuming
    from a checkpoint and optionally maintaining an EMA copy of the weights.

    Reads everything from module-level state: ``args`` (CLI namespace) and
    ``continue_train`` (resume flag). Saves ``model.pt`` every epoch and
    ``top1.pt`` whenever validation accuracy improves, both under ``args.save``.
    """
    torch.cuda.set_device(args.gpu)
    # cudnn autotuner: fastest conv algorithms for fixed input shapes.
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    cur_epoch = 0
    # SECURITY NOTE(review): eval() on a CLI string executes arbitrary code if
    # args.arch is ever attacker-controlled; here it is expected to be a Python
    # literal list encoding the architecture (ast.literal_eval would be safer).
    net = eval(args.arch)
    print(net)
    # Each node of the searched cell is described by 4 entries of the flat list,
    # hence node_num = len(net) / 4. Same code is used for both normal and
    # reduction cells ([code, code]).
    code = gen_code_from_list(net, node_num=int((len(net) / 4)))
    genotype = translator([code, code], max_node=int((len(net) / 4)))
    print(genotype)
    model_ema = None
    if not continue_train:
        print('train from the scratch')
        # 10 = number of CIFAR-10 classes.
        model = Network(args.init_ch, 10, args.layers, args.auxiliary, genotype).cuda()
        print("model init params values:", flatten_params(model))
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        # CutMix-aware loss: True presumably enables soft/one-hot target
        # handling for mixed labels — TODO confirm against its definition.
        criterion = CutMixCrossEntropyLoss(True).cuda()
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.wd
        )
        if args.model_ema:
            # Empty device string keeps the EMA copy on the same device as the
            # model; 'cpu' offloads it to save GPU memory.
            model_ema = ModelEma(
                model,
                decay=args.model_ema_decay,
                device='cpu' if args.model_ema_force_cpu else '')
    else:
        print('continue train from checkpoint')
        # NOTE(review): duplicated model/criterion/optimizer construction from
        # the branch above — keep the two in sync if hyperparameters change.
        model = Network(args.init_ch, 10, args.layers, args.auxiliary, genotype).cuda()
        criterion = CutMixCrossEntropyLoss(True).cuda()
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.wd
        )
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        checkpoint = torch.load(args.save+'/model.pt')
        model.load_state_dict(checkpoint['model_state_dict'])
        # NOTE(review): checkpoints store the index of the last *completed*
        # epoch, and training below resumes at range(cur_epoch, ...) — so the
        # saved epoch is re-run once on resume. Verify this is intended.
        cur_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if args.model_ema:
            # resume= restores the EMA weights stored under 'state_dict_ema'
            # in the same checkpoint file.
            model_ema = ModelEma(
                model,
                decay=args.model_ema_decay,
                device='cpu' if args.model_ema_force_cpu else '',
                resume=args.save + '/model.pt')
    train_transform, valid_transform = utils._auto_data_transforms_cifar10(args)
    ds_train = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    # NOTE(review): hard-coded override — any --cv value from the CLI is
    # discarded here, so the cross-validation split branch below is dead code
    # as written and the full training set is always used.
    args.cv = -1
    if args.cv >= 0:
        # 5-fold stratified 80/20 split; advance the iterator to select the
        # args.cv-th fold as (train_idx, valid_idx).
        sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
        sss = sss.split(list(range(len(ds_train))), ds_train.targets)
        for _ in range(args.cv + 1):
            train_idx, valid_idx = next(sss)
        ds_valid = Subset(ds_train, valid_idx)
        ds_train = Subset(ds_train, train_idx)
    else:
        # Empty held-out subset: validation below uses the CIFAR-10 test set.
        ds_valid = Subset(ds_train, [])
    # CutMix wrapper mixes up to num_mix=2 samples per example with prob=0.5,
    # Beta(1.0, 1.0) mixing ratio; 10 = number of classes.
    train_queue = torch.utils.data.DataLoader(
        CutMix(ds_train, 10,
        beta=1.0, prob=0.5, num_mix=2),
        batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)
    # NOTE(review): shuffle=True on the evaluation loader is harmless for the
    # accuracy computed here but unusual; consider shuffle=False.
    valid_queue = torch.utils.data.DataLoader(
        dset.CIFAR10(root=args.data, train=False, transform=valid_transform),
        batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)
    # Cosine annealing over the full run; T_max = total epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
    best_acc = 0.0
    if continue_train:
        # Replay the scheduler to restore the LR it had at cur_epoch, since the
        # scheduler itself is rebuilt (not loaded) on resume.
        for i in range(cur_epoch+1):
            scheduler.step()
    for epoch in range(cur_epoch, args.epochs):
        print('cur_epoch is', epoch)
        # NOTE(review): stepping *before* train() is the pre-PyTorch-1.1
        # convention; with modern schedulers this shifts the LR curve by one
        # epoch and get_lr() is deprecated in favor of get_last_lr(). Left
        # as-is because the resume replay above assumes this exact order.
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # Linearly ramp drop-path probability from 0 toward args.drop_path_prob
        # over the course of training; mirrored onto the EMA copy so its
        # forward pass (used by infer below) sees the same rate.
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        if model_ema is not None:
            model_ema.ema.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        train_acc, train_obj = train(train_queue, model, criterion, optimizer, epoch, model_ema)
        logging.info('train_acc: %f', train_acc)
        # EMA weights are only evaluated when kept on GPU; a CPU-resident EMA
        # copy is skipped here.
        if model_ema is not None and not args.model_ema_force_cpu:
            valid_acc_ema, valid_obj_ema = infer(valid_queue, model_ema.ema, criterion, ema=True)
            logging.info('valid_acc_ema %f', valid_acc_ema)
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)
        if valid_acc > best_acc:
            best_acc = valid_acc
            print('this model is the best')
            # Best-so-far snapshot (raw model only, no EMA weights).
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, os.path.join(args.save, 'top1.pt'))
        print('current best acc is', best_acc)
        logging.info('best_acc: %f', best_acc)
        # Rolling resume checkpoint, overwritten every epoch; includes the EMA
        # weights when EMA is enabled so the resume path above can restore them.
        if model_ema is not None:
            torch.save(
            {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
            'state_dict_ema': get_state_dict(model_ema)},
            os.path.join(args.save, 'model.pt'))
        else:
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, os.path.join(args.save, 'model.pt'))
        # NOTE(review): message names 'trained.pt' but the file written above
        # is 'model.pt' (and 'top1.pt' for the best) — misleading log line.
        print('saved to: trained.pt')