def train_cn_image_consist()

in imagenet.py [0:0]

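The function body below relies on module-level imports and helpers defined elsewhere in imagenet.py (AverageMeter, error, cn_op_2ins_space_chan, and a global args namespace with beta, crop, cn_prob, consist_wt and print_freq). A minimal sketch of the imports it assumes; anything beyond the standard libraries is an assumption about the surrounding module, not taken from the source:

import time

import numpy as np
import torch
import torch.nn.functional as F

# helpers assumed to be defined in (or imported into) imagenet.py:
#   AverageMeter, error, cn_op_2ins_space_chan, args
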

def train_cn_image_consist(model, train_loader, optimizer):
    """Train for one epoch."""
    print('running train_cn_image_consist')
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    s_losses = AverageMeter()
    c_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.train()
    end = time.time()
    # make sure using crop because the two image augmentations should be different
    assert args.beta is not None
    assert args.crop in ['both', 'style', 'content']
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.cuda(), target.cuda()

        # compute output
        # apply CrossNorm + consistency training with probability args.cn_prob
        r = np.random.rand()
        if r < args.cn_prob:
            logits_clean = model(input, aug=False)
            # Cross-entropy is only computed on clean images
            loss = F.cross_entropy(logits_clean, target)

            # two independently CrossNorm-augmented views of the same batch
            input_aug1 = cn_op_2ins_space_chan(input, beta=args.beta, crop=args.crop)
            logits_aug1 = model(input_aug1, aug=False)

            input_aug2 = cn_op_2ins_space_chan(input, beta=args.beta, crop=args.crop)
            logits_aug2 = model(input_aug2, aug=False)
            p_clean = F.softmax(logits_clean, dim=1)
            p_aug1 = F.softmax(logits_aug1, dim=1)
            p_aug2 = F.softmax(logits_aug2, dim=1)

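            # Jensen-Shannon-style consistency loss: the mean KL divergence from
            # each of the three predictive distributions to their mixture.
            # F.kl_div expects log-probabilities as its first argument, hence
            # p_mixture is passed in log space.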
            # Clamp mixture distribution to avoid exploding KL divergence
            p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
            consist_loss = (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                            F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                            F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

            s_losses.update(loss.item(), input.size(0))
            c_losses.update(consist_loss.item(), input.size(0))
            loss += args.consist_wt * consist_loss
        else:
            # plain supervised training on clean images
            logits_clean = model(input, aug=False)
            loss = F.cross_entropy(logits_clean, target)
            s_losses.update(loss.item(), input.size(0))

        # measure top-1 / top-5 error (on the clean logits) and record the total loss
        err1, err5 = error(logits_clean, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(err1.item(), input.size(0))
        top5.update(err5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Iter: [{0}/{1}]\t'
                  'Supervised Loss {s_losses.val:.4f} ({s_losses.avg:.4f})\t'
                  'Consistency Loss {c_losses.val:.4f} ({c_losses.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                      i, len(train_loader),
                      s_losses=s_losses, c_losses=c_losses, loss=losses))

    # top1 accumulates the top-1 error from the error() helper, so this returns the average training error
    return top1.avg
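
For context, a minimal sketch of how this function might be driven from the epoch loop in imagenet.py; everything except the train_cn_image_consist call itself (args.epochs, the commented scheduler hook) is an assumption about the surrounding script:

for epoch in range(args.epochs):          # args.epochs: assumed field of args
    train_err1 = train_cn_image_consist(model, train_loader, optimizer)
    # scheduler.step()                    # assumed: typical LR-schedule hook
    print('Epoch {:d}: train top-1 error {:.2f}'.format(epoch, train_err1))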