def main()

in cifar.py


def main():
    torch.manual_seed(1)
    np.random.seed(1)

    # datasets
    if 'augmix' in args.exp_id:
        # AugMix branch: ToTensor/Normalize are applied later via `preprocess`
        # inside AugMixDataset, so only PIL-level augmentations are composed here.
        train_transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.RandomCrop(32, padding=4)])
    else:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])

    preprocess = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5] * 3, [0.5] * 3)])
    test_transform = preprocess

    if args.dataset.lower() == 'cifar-10':
        print('using cifar-10 data ...')
        train_data = datasets.CIFAR10(
            root=args.data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(
            root=args.data_dir, train=False, transform=test_transform, download=True)
        base_c_path = args.corrupt_data_dir
        num_classes = 10
    elif args.dataset.lower() == 'cifar-100':
        print('using cifar-100 data ...')
        train_data = datasets.CIFAR100(
            root=args.data_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(
            root=args.data_dir, train=False, transform=test_transform, download=True)
        base_c_path = args.corrupt_data_dir
        num_classes = 100
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    assert os.path.isdir(base_c_path)
    if 'augmix' in args.exp_id:
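        # with no_jsd=False, the AugMix wrapper presumably yields a
        # (clean, aug1, aug2) tuple per sample so a JSD consistency loss
        # can be computed during training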
        train_data = AugMixDataset(train_data, preprocess, all_ops=False, mixture_width=3,
                                   mixture_depth=-1, aug_severity=3, no_jsd=False, image_size=32)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=1000,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # model
    print('model: {}'.format(args.model))
    if args.model == 'wideresnet':
        net = WideResNet(40, num_classes=num_classes, widen_factor=2, drop_rate=0,
                         active_num=args.active_num, pos=args.pos,
                         beta=args.beta, crop=args.crop, cnsn_type=args.cnsn_type)
    elif args.model == 'allconv':
        net = AllConvNet(num_classes, active_num=args.active_num, pos=args.pos,
                         beta=args.beta, crop=args.crop,
                         cnsn_type=args.cnsn_type)
    elif args.model == 'resnext':
        net = resnext29(num_classes=num_classes, config=args)
    elif args.model == 'densenet':
        net = densenet(num_classes=num_classes, config=args)
    else:
        raise Exception('unknown model: {}'.format(args.model))

    para_num = sum(p.numel() for p in net.parameters())
    print('model param #: {}'.format(para_num))

    net = torch.nn.DataParallel(net).cuda()
    cudnn.benchmark = True

    # optimizer
    optimizer = optim.SGD(net.parameters(), args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay,
                          nesterov=True)
    for group in optimizer.param_groups:
        print('lr: {}, weight_decay: {}, momentum: {}, nesterov: {}'
              .format(group['lr'], group['weight_decay'], group['momentum'], group['nesterov']))

    # lr scheduler
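    # lr_lambda returns a multiplicative factor applied to the base lr; get_lr
    # presumably anneals it from args.lr down to ~1e-6 over
    # args.epochs * len(train_loader) steps (a sketch of get_lr follows main()).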
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
            step,
            args.epochs * len(train_loader),
            1,  # lr_lambda computes multiplicative factor
            1e-6 / args.lr))

    if args.resume:
        # print_logits(net, train_loader, 100)
        print('resume checkpoint: {}'.format(args.resume))
        exp_dir_idx = args.resume.rindex('/')
        exp_dir = args.resume[:exp_dir_idx]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            # print('exp_dir: {}'.format(exp_dir))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            # fall back to a fresh run so start_epoch/best_acc are defined below
            start_epoch = 0
            best_acc = 0.
    else:
        start_epoch = 0
        best_acc = 0.
        exp_dir = get_log_dir_path(args.exp_dir, args.exp_id)
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)

    if args.evaluate:
        # Evaluate clean accuracy first because test_c mutates underlying data
        test_loss, test_acc = test(net, test_loader)
        print('Clean\n\tTest Loss {:.3f} | Test Error {:.2f}'.format(
            test_loss, 100 - 100. * test_acc))

        test_c_acc = test_c(net, test_data, base_c_path)
        print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
        return

    print('exp_dir: {}'.format(exp_dir))
    log_file = os.path.join(exp_dir, 'log.txt')
    names = ['epoch', 'lr', 'Train Loss', 'Test Err1', 'Best Test Err1']
    with open(log_file, 'a') as f:
        f.write('dataset: {}\n'.format(args.dataset))
        f.write('batch size: {}\n'.format(args.batch_size))
        f.write('lr: {}\n'.format(args.lr))
        f.write('momentum: {}\n'.format(args.momentum))
        f.write('weight_decay: {}\n'.format(args.weight_decay))
        for per_name in names:
            f.write(per_name + '\t')
        f.write('\n')
    # print('=> Training the base model')
    print('start_epoch {}'.format(start_epoch))
    print('total epochs: {}'.format(args.epochs))
    print('best_acc: {}'.format(best_acc))
    # print('best_err5: {}'.format(best_err5))

    if args.cn_prob:
        print('cn_prob: {}'.format(args.cn_prob))
    if args.consist_wt:
        print('consist_wt: {}'.format(args.consist_wt))
    for epoch in range(start_epoch, args.epochs):
        lr = optimizer.param_groups[0]['lr']

        if 'augmix' in args.exp_id and 'cn' in args.cnsn_type:
            assert args.cn_prob > 0 and args.consist_wt > 0
            train_loss_ema = train_cn_augmix(net, train_loader, optimizer, scheduler)
        elif 'consist' in args.exp_id and 'cn' in args.cnsn_type:
            assert args.cn_prob > 0 and args.consist_wt > 0
            train_loss_ema = train_cn_consistency(net, train_loader, optimizer, scheduler)
        elif 'cn' in args.cnsn_type:
            assert args.cn_prob > 0
            train_loss_ema = train_cn(net, train_loader, optimizer, scheduler)
        else:
            train_loss_ema = train(net, train_loader, optimizer, scheduler)

        test_loss, test_acc = test(net, test_loader)
        # test_c_acc = test_c(net, test_data, base_c_path)

        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)

        save_checkpoint(net, {
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, exp_dir, epoch=None)

        values = [train_loss_ema, 100 - 100. * test_acc, 100 - 100. * best_acc]
        with open(log_file, 'a') as f:
            f.write('{:d}\t'.format(epoch))
            f.write('{:g}\t'.format(lr))
            for per_value in values:
                f.write('{:2.2f}\t'.format(per_value))
            f.write('\n')
        print('exp_dir: {}'.format(exp_dir))

    test_c_acc = test_c(net, test_data, base_c_path)
    print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
    with open(log_file, 'a') as f:
        f.write('{:2.2f}\t'.format(100 - 100. * test_c_acc))
        f.write('\n')
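

The LambdaLR scheduler in main() only supplies a multiplicative factor; the shape of the decay lives in the get_lr helper. Below is a minimal sketch of such a helper, assuming the cosine-annealing schedule used by the AugMix reference code; the name and signature match the call in main(), but the body is an assumption, not the file's actual implementation.

import numpy as np

def get_lr(step, total_steps, lr_max, lr_min):
    """Assumed cosine-annealing factor: decays from lr_max to lr_min over total_steps."""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

With lr_max=1 and lr_min=1e-6/args.lr as passed by main(), multiplying this factor by the base learning rate gives a cosine decay from args.lr at step 0 down to roughly 1e-6 at the final step.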