# main() — in imagenet.py [0:0]


def main():
  """Train an ImageNet classifier (optionally with AugMix / CrossNorm-style
  augmentation) and evaluate clean accuracy plus corruption robustness (mCE).

  Behavior is driven entirely by the module-level ``args`` namespace:
  ``exp_id`` selects the augmentation/training routine, ``model`` selects
  the backbone, and ``evaluate`` / ``resume`` toggle eval-only mode and
  checkpoint resumption. Training progress is appended to ``log.txt`` in
  the experiment directory.
  """
  # Fixed seeds for reproducibility (GPU ops may still be nondeterministic).
  torch.manual_seed(1)
  np.random.seed(1)

  # ---- Data pipelines --------------------------------------------------
  # Standard ImageNet channel statistics.
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  preprocess = transforms.Compose(
      [transforms.ToTensor(),
       transforms.Normalize(mean, std)])
  if 'augmix' in args.exp_id:
      # AugMixDataset applies `preprocess` itself, so the train transform
      # here is geometry-only (crop + flip).
      print('using augmix data preprocessing...')
      train_transform = transforms.Compose(
          [transforms.RandomResizedCrop(224),
           transforms.RandomHorizontalFlip()])
  else:
      print('using only standard data preprocessing...')
      train_transform = transforms.Compose(
          [transforms.RandomResizedCrop(224),
           transforms.RandomHorizontalFlip(),
           preprocess])

  test_transform = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      preprocess,
  ])

  traindir = os.path.join(args.data_dir, 'train')
  valdir = os.path.join(args.data_dir, 'validation')
  train_dataset = datasets.ImageFolder(traindir, train_transform)

  # Corrupted data is required later by test_c(); fail early if absent.
  assert os.path.isdir(args.corrupt_data_dir)
  if 'augmix' in args.exp_id:
    train_dataset = AugMixDataset(train_dataset, preprocess, all_ops=False, mixture_width=3,
                                  mixture_depth=-1, aug_severity=1, no_jsd=False, image_size=224)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=args.batch_size,
      shuffle=True,
      num_workers=args.workers,
      pin_memory=True)
  test_dataset = datasets.ImageFolder(valdir, test_transform)

  val_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=1000,
      shuffle=False,
      num_workers=args.workers,
      pin_memory=True)

  # ---- Model -----------------------------------------------------------
  print('model: {}'.format(args.model))
  if args.model == 'resnet50':
      net = resnet50(args)
  elif args.model == 'resnet50_ibn_a':
      net = resnet50_ibn_a(args)
  elif args.model == 'resnet50_ibn_b':
      net = resnet50_ibn_b(args)
  else:
      # Fail fast with a clear message instead of a NameError on `net` below.
      raise ValueError('unknown model: {}'.format(args.model))

  para_num = sum(p.numel() for p in net.parameters())
  print('model param #: {}'.format(para_num))

  if args.pretrained:
      print('pretrained model: {}'.format(args.pretrained))
      state_dict = torch.load(args.pretrained)
      # strict=False tolerates missing/unexpected keys (e.g. added norm layers).
      net.load_state_dict(state_dict, strict=False)

  print('optimizer momentum: {}'.format(args.momentum))
  print('optimizer weight_decay: {}'.format(args.weight_decay))

  optimizer = torch.optim.SGD(
      net.parameters(),
      args.lr,
      momentum=args.momentum,
      weight_decay=args.weight_decay)

  # Distribute model across all visible GPUs
  net = torch.nn.DataParallel(net).cuda()
  cudnn.benchmark = True

  # Defaults; overwritten below when resuming from a valid checkpoint.
  # (Previously `best_acc` was left unbound when the resume file was
  # missing, crashing at the first use further down.)
  start_epoch = 0
  best_acc = 0.

  if args.resume:
      # Reuse the checkpoint's directory as the experiment directory.
      # NOTE(review): assumes args.resume contains a '/' — rindex raises
      # ValueError otherwise.
      exp_dir_idx = args.resume.rindex('/')
      exp_dir = args.resume[:exp_dir_idx]
      if os.path.isfile(args.resume):
          print("=> loading checkpoint '{}'".format(args.resume))
          checkpoint = torch.load(args.resume)
          start_epoch = checkpoint['epoch']
          best_acc = checkpoint['best_acc']
          net.load_state_dict(checkpoint['state_dict'])
          optimizer.load_state_dict(checkpoint['optimizer'])
      else:
          print("=> no checkpoint found at '{}'".format(args.resume))
  else:
      exp_dir = get_log_dir_path(args.exp_dir, args.exp_id)
      if not os.path.exists(exp_dir):
          os.makedirs(exp_dir)

  if args.evaluate:
    # Eval-only mode: clean accuracy + corruption benchmark, then exit.
    test_loss, test_acc1 = test(net, val_loader)
    print('Clean\n\tTest Loss {:.3f} | Test Acc1 {:.3f}'.format(
        test_loss, 100 * test_acc1))
    corruption_accs = test_c(net, test_transform)
    for c in CORRUPTIONS:
        print('\t'.join(map(str, [c] + corruption_accs[c])))

    mce, ce_dict = compute_mce(corruption_accs)
    print_ces(ce_dict)
    print('mCE (normalized by AlexNet): ', mce)
    return

  # ---- Training log header --------------------------------------------
  print('exp_dir: {}'.format(exp_dir))
  log_file = os.path.join(exp_dir, 'log.txt')
  # BUG FIX: a missing comma previously fused the last two column names
  # ('Test Err1' 'Best Test Err1' -> 'Test Err1Best Test Err1').
  names = ['epoch', 'lr', 'Train Err1', 'Test Err1', 'Best Test Err1']
  with open(log_file, 'a') as f:
      f.write('batch size: {}\n'.format(args.batch_size))
      f.write('lr: {}\n'.format(args.lr))
      f.write('momentum: {}\n'.format(args.momentum))
      f.write('weight_decay: {}\n'.format(args.weight_decay))
      for per_name in names:
          f.write(per_name + '\t')
      f.write('\n')
  print('start_epoch {}'.format(start_epoch))
  print('total epochs: {}'.format(args.epochs))
  print('best_acc: {}'.format(best_acc))
  print('Beginning training from epoch:', start_epoch)

  if args.cn_prob:
      print('cn_prob: {}'.format(args.cn_prob))
  if args.consist_wt:
      print('consist_wt: {}'.format(args.consist_wt))

  # ---- Training loop ---------------------------------------------------
  for epoch in range(start_epoch, args.epochs):
    adjust_learning_rate(optimizer, epoch)
    lr = optimizer.param_groups[0]['lr']
    print('lr: {}'.format(lr))

    # exp_id substrings select the training routine; all CrossNorm-in-image
    # variants require a positive crossnorm probability.
    if 'augmix' in args.exp_id:
        assert args.cn_prob > 0
        train_err1 = train_cn_image_augmix(net, train_loader, optimizer)
    elif 'consist' in args.exp_id:
        assert args.cn_prob > 0
        train_err1 = train_cn_image_consist(net, train_loader, optimizer)
    elif 'cn' in args.exp_id:
        assert args.cn_prob > 0
        train_err1 = train_cn_image(net, train_loader, optimizer)
    else:
        train_err1 = train(net, train_loader, optimizer)

    test_loss, test_acc = test(net, val_loader)

    is_best = test_acc > best_acc
    best_acc = max(test_acc, best_acc)

    save_checkpoint(net, {
        'epoch': epoch + 1,
        'state_dict': net.state_dict(),
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, is_best, exp_dir, epoch=epoch)

    # Log errors (percent); test_acc/best_acc are fractions in [0, 1].
    values = [train_err1, 100 - 100. * test_acc, 100 - 100. * best_acc]
    with open(log_file, 'a') as f:
        f.write('{:d}\t'.format(epoch))
        f.write('{:g}\t'.format(lr))
        for per_value in values:
            f.write('{:2.2f}\t'.format(per_value))
        f.write('\n')
    print('exp_dir: {}'.format(exp_dir))

  # ---- Final corruption-robustness evaluation --------------------------
  corruption_accs = test_c(net, test_transform)
  for c in CORRUPTIONS:
    print('\t'.join(map(str, [c] + corruption_accs[c])))

  mce, ce_dict = compute_mce(corruption_accs)
  print_ces(ce_dict)
  print('mCE (normalized by AlexNet): {:.2f}'.format(mce))
  with open(log_file, 'a') as f:
    f.write('individual corruption errors: \n')
    for per in CORRUPTIONS:
        f.write('{0}: {ce:.2f}\n'.format(per, ce=ce_dict[per]))
    f.write('mCE: {:.2f}\t'.format(mce))
    f.write('\n')