in LaNAS/one-shot_LaNAS/Evaluate/super_individual_train.py [0:0]
def _save_checkpoint(path, epoch, model, optimizer, scheduler, best_acc):
    """Write training state to ``path`` in the format ``run`` resumes from.

    Keeps the keys earlier versions of this script wrote ('epoch',
    'model_state_dict', 'scheduler', 'optimizer_state_dict') and adds
    'best_acc' so a resumed run remembers the best validation accuracy.
    """
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'scheduler': scheduler,
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc},
               path)


def _restore_scheduler_state(scheduler, saved):
    """Copy LR-scheduler progress from a checkpoint entry into ``scheduler``.

    ``saved`` may be a pickled scheduler object (what ``_save_checkpoint``
    stores) or a plain scheduler state dict.
    """
    state = saved.state_dict() if hasattr(saved, 'state_dict') else saved
    scheduler.load_state_dict(state)


def run():
    """Train the network derived from ``args.masked_code`` on CIFAR-10.

    Configuration comes from the module-level ``args`` namespace. When the
    module-level ``continue_train`` flag is true, model/optimizer/scheduler
    state and the epoch counter are restored from ``<args.save>/model.pt``.
    After every epoch the latest state is written to ``<args.save>/model.pt``,
    and to ``<args.save>/top_1.pt`` whenever validation accuracy improves.
    Exits the process with status 1 if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d', args.gpu)
    logging.info("args = %s", args)

    layers_type = [
        'max_pool_3x3',
        'skip_connect',
        'sep_conv_3x3',
        'sep_conv_5x5',
    ]
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line; if masked_code is always a literal list, ast.literal_eval is safer.
    supernet_normal, supernet_reduce = translator.encoding_to_masks(eval(args.masked_code))
    supernet_normal = translator.expend_to_supernet_code(supernet_normal)
    supernet_reduce = translator.expend_to_supernet_code(supernet_reduce)

    if continue_train:
        print('train from checkpoints')
    else:
        print('train from scratch!')

    # Build the model once for both paths. The resume path used to omit
    # `multiplier`, so it could construct a different architecture than the
    # one whose weights were saved.
    model = Network(supernet_normal, supernet_reduce, layers_type, args.init_ch, CIFAR_CLASSES, args.layers,
                    args.auxiliary,
                    steps=len(supernet_normal), multiplier=len(supernet_normal))
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.wd,
    )
    # The scheduler must be bound to *this* optimizer. A scheduler unpickled
    # from a checkpoint drives a deserialized copy of the old optimizer, so
    # its step() would never change the learning rate actually used here.
    # It is created before optimizer.load_state_dict so the restored LR is
    # not overwritten by the scheduler's initial step.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    cur_epoch = 0
    best_acc = 0.0
    if continue_train:
        checkpoint = torch.load(os.path.join(args.save, 'model.pt'), map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        _restore_scheduler_state(scheduler, checkpoint['scheduler'])
        cur_epoch = checkpoint['epoch']
        # Older checkpoints lack 'best_acc'; fall back to the previous default.
        best_acc = checkpoint.get('best_acc', 0.0)

    train_transform, valid_transform = utils._data_transforms_cifar10(args, args.cutout_length)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

    top1_path = os.path.join(args.save, 'top_1.pt')
    latest_path = os.path.join(args.save, 'model.pt')
    for epoch in range(cur_epoch, args.epochs):
        print("=====> current epoch:", epoch)
        logging.info('=====> current epoch: %d', epoch)
        # Report the LR the optimizer will really use this epoch;
        # scheduler.get_lr() is not meant to be called outside step().
        logging.info('epoch %d lr %e', epoch, optimizer.param_groups[0]['lr'])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        train_acc, _ = train(train_queue, model, criterion, optimizer, args.grad_clip, args.report_freq)
        scheduler.step()
        logging.info('train_acc %f', train_acc)
        valid_acc, _ = infer(valid_queue, model, criterion, args.report_freq)
        logging.info('valid_acc %f', valid_acc)
        if valid_acc > best_acc:
            best_acc = valid_acc
            print('this model is the best')
            logging.info('this model is the best')
            _save_checkpoint(top1_path, epoch + 1, model, optimizer, scheduler, best_acc)
        _save_checkpoint(latest_path, epoch + 1, model, optimizer, scheduler, best_acc)
        logging.info('best_acc: %f', best_acc)
        print('current best acc is', best_acc)