def run()

in LaNAS/Distributed_LaNAS/clientX/train_client.py


def run(net, init_ch=32, layers=20, auxiliary=True, lr=0.025, momentum=0.9, wd=3e-4, cutout=True, cutout_length=16, data='../data', batch_size=96, epochs=600, drop_path_prob=0.2, auxiliary_weight=0.4):
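    # Each sampled architecture gets its own checkpoint directory, keyed by an
    # MD5 hash of its JSON-encoded genotype.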
    save = '/checkpoint/linnanwang/nasnet/' + hashlib.md5(json.dumps(net).encode()).hexdigest()
    utils.create_exp_dir(save, scripts_to_save=glob.glob('*.py'))

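    # Log to stdout and to log.txt inside the checkpoint directory.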
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)


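    # Fix the random seeds and enable cuDNN benchmarking on GPU 0.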
    np.random.seed(0)
    torch.cuda.set_device(0)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(0)
    logging.info('gpu device = %d', 0)


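    # Build the evaluation network (10 CIFAR-10 classes) from the sampled genotype.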
    genotype = net
    model = Network(init_ch, 10, layers, auxiliary, genotype).cuda()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

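    # Cross-entropy loss and SGD with momentum and weight decay.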
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr,
        momentum=momentum,
        weight_decay=wd
    )
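    # Mixed-precision training via NVIDIA apex (opt_level O3).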
    model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O3")



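    # CIFAR-10 data pipeline; cutout augmentation is applied to the training transform.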
    train_transform, valid_transform = utils._data_transforms_cifar10(cutout, cutout_length)
    train_data = dset.CIFAR10(root=data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=2)

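    # Cosine-annealed learning rate over the full epoch budget.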
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epochs))

    best_acc = 0.0

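    # Main loop: train one epoch, evaluate on the held-out split, and track the best accuracy.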
    for epoch in range(epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_last_lr()[0])
        # Drop-path probability is ramped up linearly over the course of training.
        model.drop_path_prob = drop_path_prob * epoch / epochs


        train_acc, train_obj = train(train_queue, model, criterion, optimizer, auxiliary=auxiliary, auxiliary_weight=auxiliary_weight)
        logging.info('train_acc: %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        # Advance the cosine LR schedule once per epoch, after the optimizer updates.
        scheduler.step()

        # Checkpoint the weights on a new best validation accuracy, skipping the first 50 epochs.
        if valid_acc > best_acc and epoch >= 50:
            logging.info('new best model, saving checkpoint to %s', os.path.join(save, 'model.pt'))
            torch.save(model.state_dict(), os.path.join(save, 'model.pt'))
        if valid_acc > best_acc:
            best_acc = valid_acc
        logging.info('current best acc: %f', best_acc)

        # Training is cut short at 100 epochs even though the LR schedule spans `epochs`.
        if epoch == 100:
            break

    return best_acc
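
For reference, the checkpoint-directory naming used at the top of run() can be reproduced in isolation. This is a minimal sketch; example_net is a made-up placeholder, not the genotype encoding actually produced by the LaNAS search.

# Standalone sketch of the checkpoint-path derivation used in run().
# NOTE: example_net is a hypothetical placeholder architecture encoding.
import hashlib
import json

example_net = {'normal': [['sep_conv_3x3', 0]], 'reduce': [['max_pool_3x3', 1]]}
digest = hashlib.md5(json.dumps(example_net).encode()).hexdigest()
save_dir = '/checkpoint/linnanwang/nasnet/' + digest
print(save_dir)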