def main()

in LaNAS/LaNet/CIFAR10/train.py [0:0]


def _build_model_and_optimizer(genotype):
    """Create the 10-class network, its CutMix loss and its SGD optimizer on the GPU."""
    model = Network(args.init_ch, 10, args.layers, args.auxiliary, genotype).cuda()
    criterion = CutMixCrossEntropyLoss(True).cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.wd,
    )
    return model, criterion, optimizer


def _checkpoint_state(epoch, model, optimizer):
    """Return the base checkpoint dict (epoch, model weights, optimizer state)."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }


def main():
    """Train the network described by ``args.arch`` on CIFAR-10.

    Reads its configuration from the module-level ``args`` and
    ``continue_train`` names (defined outside this block).  When
    ``continue_train`` is true, model and optimizer state are restored from
    ``<args.save>/model.pt`` and training continues with the epoch after the
    one stored in that checkpoint.

    After every epoch the current state is written to ``<args.save>/model.pt``
    (including the EMA weights when ``args.model_ema`` is set); whenever the
    validation accuracy improves, a copy is also written to
    ``<args.save>/top1.pt``.
    """
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # SECURITY NOTE(review): eval() executes whatever expression is passed as
    # --arch.  Only run this script with trusted arguments; if the value is
    # always a plain literal, ast.literal_eval would be a safer choice.
    net = eval(args.arch)
    print(net)
    node_num = int(len(net) / 4)
    code = gen_code_from_list(net, node_num=node_num)
    genotype = translator([code, code], max_node=node_num)
    print(genotype)

    checkpoint_path = os.path.join(args.save, 'model.pt')
    start_epoch = 0
    best_acc = 0.0
    ema_kwargs = {}

    if continue_train:
        print('continue train from checkpoint')
    else:
        print('train from the scratch')

    model, criterion, optimizer = _build_model_and_optimizer(genotype)

    if not continue_train:
        print("model init params values:", flatten_params(model))
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    if continue_train:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # The checkpoint records the epoch that has already been completed,
        # so training resumes with the following one.
        start_epoch = checkpoint['epoch'] + 1
        # Older checkpoints do not carry 'best_acc'; fall back to 0.0.
        best_acc = checkpoint.get('best_acc', 0.0)

        # The restored optimizer holds the already-decayed learning rate.
        # Some PyTorch releases derive each cosine step from the optimizer's
        # current LR, so reset it to the initial value before the schedule is
        # replayed below; otherwise the decay would be applied twice.
        for group in optimizer.param_groups:
            group['lr'] = group.get('initial_lr', args.lr)

        ema_kwargs['resume'] = checkpoint_path

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            **ema_kwargs)

    train_transform, valid_transform = utils._auto_data_transforms_cifar10(args)
    ds_train = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)

    # The cross-validation split is switched off here: the full training set
    # is used for training and the CIFAR-10 test split for evaluation.
    args.cv = -1

    train_queue = torch.utils.data.DataLoader(
        CutMix(ds_train, 10, beta=1.0, prob=0.5, num_mix=2),
        batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)

    valid_queue = torch.utils.data.DataLoader(
        dset.CIFAR10(root=args.data, train=False, transform=valid_transform),
        batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    # Replay the schedule for the epochs that were already trained so that a
    # resumed run sees the same learning rate as an uninterrupted one.
    for _ in range(start_epoch):
        scheduler.step()

    for epoch in range(start_epoch, args.epochs):
        print('cur_epoch is', epoch)
        # NOTE: the scheduler is stepped before the epoch's training, as in
        # the original code; changing the order would shift the LR per epoch.
        scheduler.step()
        # Read the LR from the optimizer: that is the value actually used,
        # independent of the scheduler API of the installed PyTorch version.
        logging.info('epoch %d lr %e', epoch, optimizer.param_groups[0]['lr'])

        drop_path_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob = drop_path_prob
        if model_ema is not None:
            model_ema.ema.drop_path_prob = drop_path_prob

        train_acc, train_obj = train(train_queue, model, criterion, optimizer, epoch, model_ema)
        logging.info('train_acc: %f', train_acc)

        if model_ema is not None and not args.model_ema_force_cpu:
            valid_acc_ema, valid_obj_ema = infer(valid_queue, model_ema.ema, criterion, ema=True)
            logging.info('valid_acc_ema %f', valid_acc_ema)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        if valid_acc > best_acc:
            best_acc = valid_acc
            print('this model is the best')
            torch.save(_checkpoint_state(epoch, model, optimizer),
                       os.path.join(args.save, 'top1.pt'))
        print('current best acc is', best_acc)
        logging.info('best_acc: %f', best_acc)

        state = _checkpoint_state(epoch, model, optimizer)
        # Persist the best accuracy so a resumed run does not overwrite
        # top1.pt with a worse model.
        state['best_acc'] = float(best_acc)
        if model_ema is not None:
            state['state_dict_ema'] = get_state_dict(model_ema)
        torch.save(state, checkpoint_path)

        print('saved to:', checkpoint_path)