in LaNAS/Distributed_LaNAS/clientX/continue_train.py [0:0]
def main():
    """Resume CIFAR-10 training of a saved model from ``args.cur_epoch``.

    Loads the pickled model from ``args.model_path/model.pt``, rebuilds the
    SGD optimizer, the CIFAR-10 data loaders and the cosine learning-rate
    schedule, then runs epochs ``args.cur_epoch`` .. ``args.epochs - 1``.

    In every epoch the model is first evaluated on the test split and then
    trained for one pass. The model is written to ``args.save/trained.pt``
    each epoch, and additionally to ``args.save/AlphaX_1.pt`` whenever the
    validation accuracy beats the best value seen in this run.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Plain attribute lookup instead of eval() on a string built from
    # args.arch: same result for a valid name, still raises AttributeError
    # for an unknown one, but cannot execute arbitrary expressions.
    # The value is not used below because the network is restored from the
    # checkpoint rather than rebuilt as
    # Network(args.init_ch, 10, args.layers, args.auxiliary, genotype).
    genotype = getattr(genotypes, args.arch)  # noqa: F841

    # NOTE(review): torch.load unpickles the whole model object; only use
    # checkpoints from a trusted source.
    model = torch.load(os.path.join(args.model_path, 'model.pt'))
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.wd,
    )

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(
        root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(
        root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False,
        pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    # Best accuracy is tracked for this run only; it starts from zero even
    # when resuming.
    best_acc = 0.0

    # The scheduler is created fresh, so step it once per already-completed
    # epoch to move the cosine schedule forward to args.cur_epoch.
    for _ in range(args.cur_epoch):
        scheduler.step()

    for epoch in range(args.cur_epoch, args.epochs):
        scheduler.step()
        # NOTE(review): newer PyTorch releases recommend get_last_lr() over
        # get_lr() for reading the current rate -- confirm against the
        # installed version before switching.
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # Evaluation happens before this epoch's training pass, so the
        # accuracy reported here belongs to the model as it entered the epoch.
        valid_acc, _ = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        if valid_acc > best_acc:
            best_acc = valid_acc
            print('this model is the best')
            torch.save(model, os.path.join(args.save, 'AlphaX_1.pt'))
        torch.save(model, os.path.join(args.save, 'trained.pt'))
        print('current best acc is', best_acc)

        train_acc, _ = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc: %f', train_acc)
        print('saved to: trained.pt')