def train_cn_augmix()

in cifar.py [0:0]


import numpy as np
import torch
import torch.nn.functional as F

# `args` (parsed command-line flags) and `AverageMeter` are defined at module
# level in cifar.py; a sketch of `AverageMeter` follows the function.


def train_cn_augmix(net, train_loader, optimizer, scheduler):
  """Train for one epoch."""
  print('running train_cn_augmix')
  s_losses = AverageMeter()
  c_losses = AverageMeter()
  losses = AverageMeter()
  net.train()
  loss_ema = 0.
  for i, (images, targets) in enumerate(train_loader):
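    # `images` is a tuple of three tensors from the AugMix-style loader:
    # (clean, augmented view 1, augmented view 2), split apart again below.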
    optimizer.zero_grad()

    # Forward the clean and both AugMix-augmented views as one batch.
    images_all = torch.cat(images, 0).cuda()
    targets = targets.cuda()
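    # aug=False is assumed to run the plain forward pass, with the network's
    # internal (CrossNorm-style) feature augmentation disabled.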
    logits_all = net(images_all, aug=False)

    logits_clean, logits_aug1, logits_aug2 = torch.split(
        logits_all, images[0].size(0))

    # Cross-entropy is only computed on clean images
    loss = F.cross_entropy(logits_clean, targets)

    p_clean = F.softmax(logits_clean, dim=1)
    p_aug1 = F.softmax(logits_aug1, dim=1)
    p_aug2 = F.softmax(logits_aug2, dim=1)

    # Clamp mixture distribution to avoid exploding KL divergence
    p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
    consist_loss = (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
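    # Averaging the three KL terms against the clamped mean distribution
    # gives the Jensen-Shannon divergence consistency loss from AugMix.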
    s_losses.update(loss.item(), images[0].size(0))
    c_losses.update(consist_loss.item(), images[0].size(0))
    # Jensen-Shannon consistency weight; 12 is the default from AugMix.
    loss += 12 * consist_loss
    losses.update(loss.item(), images[0].size(0))

    # With probability args.cn_prob, add a second consistency term between
    # the clean predictions and two stochastic aug=True forward passes.
    if np.random.rand() < args.cn_prob:
      images_clean = images[0].cuda()  # images[0] is still on the CPU here
      logits_cn_aug1 = net(images_clean, aug=True)
      logits_cn_aug2 = net(images_clean, aug=True)
      p_cn_aug1 = F.softmax(logits_cn_aug1, dim=1)
      p_cn_aug2 = F.softmax(logits_cn_aug2, dim=1)
      # Clamp the mixture again to avoid exploding KL divergence
      p_cn_mixture = torch.clamp(
          (p_clean + p_cn_aug1 + p_cn_aug2) / 3., 1e-7, 1).log()
      cn_consist_loss = (F.kl_div(p_cn_mixture, p_clean, reduction='batchmean') +
                         F.kl_div(p_cn_mixture, p_cn_aug1, reduction='batchmean') +
                         F.kl_div(p_cn_mixture, p_cn_aug2, reduction='batchmean')) / 3.
      loss += args.consist_wt * cn_consist_loss

    loss.backward()
    optimizer.step()
    scheduler.step()
    loss_ema = loss_ema * 0.9 + float(loss) * 0.1
    if i % args.print_freq == 0:
      print('Supervised Loss {s_losses.val:.4f} ({s_losses.avg:.4f})\t'
            'Consistency Loss {c_losses.val:.4f} ({c_losses.avg:.4f})\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
             s_losses=s_losses, c_losses=c_losses, loss=losses))

  return loss_ema
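
The function relies on two module-level names that are not shown above:
`args` (the script's parsed flags, including cn_prob, consist_wt, and
print_freq) and `AverageMeter`. A minimal sketch of `AverageMeter`, assuming
the standard running-average helper used in PyTorch example code:

class AverageMeter(object):
  """Tracks the latest value and running average of a metric."""

  def __init__(self):
    self.val = 0.
    self.sum = 0.
    self.count = 0
    self.avg = 0.

  def update(self, val, n=1):
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count

A hedged usage sketch, assuming an AugMix-style loader whose batches are
(clean, aug1, aug2) image tuples and a model whose forward accepts the custom
aug flag used above; WideResNet, the hyperparameter values, and args.epochs
are illustrative assumptions, not taken from cifar.py:

net = WideResNet(depth=40, num_classes=10, widen_factor=2).cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)
# The scheduler is stepped once per batch inside train_cn_augmix, so size
# its horizon in iterations, not epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=args.epochs * len(train_loader))

for epoch in range(args.epochs):
  loss_ema = train_cn_augmix(net, train_loader, optimizer, scheduler)
  print('Epoch {}: train loss EMA {:.4f}'.format(epoch, loss_ema))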