def main()

in LaNAS/Distributed_LaNAS/clientX/continue_train.py [0:0]


def _validate_and_checkpoint(valid_queue, model, criterion, best_acc):
    """Evaluate ``model`` and write its checkpoints.

    Saves the model to ``<args.save>/AlphaX_1.pt`` when its validation
    accuracy beats ``best_acc``. Always saves it to
    ``<args.save>/trained.pt``.

    Returns:
        The updated best validation accuracy.
    """
    # `infer` is assumed to return (accuracy, objective), as unpacked here.
    valid_acc, _ = infer(valid_queue, model, criterion)
    logging.info('valid_acc: %f', valid_acc)

    if valid_acc > best_acc:
        best_acc = valid_acc
        print('this model is the best')
        torch.save(model, os.path.join(args.save, 'AlphaX_1.pt'))

    torch.save(model, os.path.join(args.save, 'trained.pt'))
    print('saved to: trained.pt')
    print('current best acc is', best_acc)
    return best_acc


def main():
    """Resume training of a pickled CIFAR-10 model from epoch ``args.cur_epoch``.

    All configuration comes from the module-level ``args``:
    seed, gpu, model_path, lr, momentum, wd, data, batch_size, epochs,
    cur_epoch, drop_path_prob and save.

    The loaded model is evaluated once before training so that it seeds the
    best-accuracy checkpoint. Each epoch then does three things:
    train, advance the cosine LR schedule, validate.
    After every validation the model is written to ``trained.pt``, and also to
    ``AlphaX_1.pt`` when it is the best seen so far.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # The whole pickled module (architecture + weights) is restored, so no
    # genotype lookup is needed. map_location places every tensor on the
    # selected GPU, even if the checkpoint was written from another device.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which cannot
    # restore a full module -- confirm the PyTorch version in use.
    model = torch.load(os.path.join(args.model_path, 'model.pt'),
                       map_location='cuda:%d' % args.gpu)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.wd
    )

    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    # Fast-forward the schedule so that epoch `cur_epoch` trains with the LR it
    # would have had in an uninterrupted run. Recent PyTorch may emit a
    # harmless warning here because no optimizer.step() has run yet.
    for _ in range(args.cur_epoch):
        scheduler.step()

    # Evaluate the loaded model first so it seeds the best-accuracy checkpoint.
    best_acc = _validate_and_checkpoint(valid_queue, model, criterion, 0.0)

    for epoch in range(args.cur_epoch, args.epochs):
        # Read the LR from the optimizer: scheduler.get_lr() is deprecated
        # outside step() in recent PyTorch, and this works in all versions.
        logging.info('epoch %d lr %e', epoch, optimizer.param_groups[0]['lr'])
        # Drop-path probability ramps linearly with training progress.
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # `train` is assumed to return (accuracy, objective), as unpacked here.
        train_acc, _ = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc: %f', train_acc)

        # Advance the schedule only after this epoch's optimizer steps
        # (PyTorch >= 1.1 ordering); the next epoch then uses the next LR.
        scheduler.step()

        # Validate and checkpoint AFTER training so no epoch's work is lost.
        best_acc = _validate_and_checkpoint(valid_queue, model, criterion, best_acc)