cifar.py [227:246]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # Split the joint logits along dim 0 into chunks of images[0].size(0)
    # rows: clean view, first augmented view, second augmented view.
    # NOTE(review): presumably logits_all comes from one forward pass over the
    # clean and two augmented batches concatenated along dim 0; the three-way
    # unpack requires torch.split to yield exactly three chunks — confirm
    # against the code above this excerpt.
    logits_clean, logits_aug1, logits_aug2 = torch.split(
        logits_all, images[0].size(0))

    # Cross-entropy is only computed on clean images
    loss = F.cross_entropy(logits_clean, targets)

    # Per-view class probabilities: softmax over dim 1, the same class axis
    # that F.cross_entropy uses for logits_clean above.
    p_clean, p_aug1, p_aug2 = F.softmax(
        logits_clean, dim=1), F.softmax(
        logits_aug1, dim=1), F.softmax(
        logits_aug2, dim=1)

    # Clamp mixture distribution to avoid exploding KL divergence:
    # p_mixture holds log(M), where M is the mean of the three per-view
    # distributions, clamped to [1e-7, 1] so the log stays finite.
    p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
    # F.kl_div(input, target) expects `input` as log-probabilities and
    # computes KL(target || exp(input)); 'batchmean' divides the summed
    # divergence by the batch size. Averaging KL(p_view || M) over the three
    # views gives a Jensen-Shannon-style consistency term (approximate, because
    # of the clamp above).
    consist_loss = (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
    # Record the supervised (cross-entropy) and consistency terms separately,
    # before they are combined. The second argument is presumably the sample
    # count for a running average — the meters are defined outside this
    # excerpt.
    s_losses.update(loss.item(), images[0].size(0))
    c_losses.update(consist_loss.item(), images[0].size(0))
    # Total loss = cross-entropy + 12 * consistency term.
    # NOTE(review): the weight 12 is hard-coded; consider exposing it as a
    # hyperparameter. An identical block appears in imagenet.py — consider
    # moving this loss computation into a shared helper.
    loss += 12 * consist_loss
    losses.update(loss.item(), images[0].size(0))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



imagenet.py [361:381]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # Split the joint logits along dim 0 into chunks of images[0].size(0)
    # rows: clean view, first augmented view, second augmented view.
    # NOTE(review): presumably logits_all comes from one forward pass over the
    # clean and two augmented batches concatenated along dim 0; the three-way
    # unpack requires torch.split to yield exactly three chunks — confirm
    # against the code above this excerpt.
    logits_clean, logits_aug1, logits_aug2 = torch.split(
        logits_all, images[0].size(0))

    # Cross-entropy is only computed on clean images
    loss = F.cross_entropy(logits_clean, targets)

    # Per-view class probabilities: softmax over dim 1, the same class axis
    # that F.cross_entropy uses for logits_clean above.
    p_clean, p_aug1, p_aug2 = F.softmax(
        logits_clean, dim=1), F.softmax(
        logits_aug1, dim=1), F.softmax(
        logits_aug2, dim=1)

    # Clamp mixture distribution to avoid exploding KL divergence:
    # p_mixture holds log(M), where M is the mean of the three per-view
    # distributions, clamped to [1e-7, 1] so the log stays finite.
    p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
    # F.kl_div(input, target) expects `input` as log-probabilities and
    # computes KL(target || exp(input)); 'batchmean' divides the summed
    # divergence by the batch size. Averaging KL(p_view || M) over the three
    # views gives a Jensen-Shannon-style consistency term (approximate, because
    # of the clamp above).
    consist_loss = (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

    # Record the supervised (cross-entropy) and consistency terms separately,
    # before they are combined. The second argument is presumably the sample
    # count for a running average — the meters are defined outside this
    # excerpt.
    s_losses.update(loss.item(), images[0].size(0))
    c_losses.update(consist_loss.item(), images[0].size(0))
    # Total loss = cross-entropy + 12 * consistency term.
    # NOTE(review): the weight 12 is hard-coded; consider exposing it as a
    # hyperparameter. An identical block appears in cifar.py — consider
    # moving this loss computation into a shared helper.
    loss += 12 * consist_loss
    losses.update(loss.item(), images[0].size(0))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



