in imagenet.py [0:0]
def train_cn_image_augmix(net, train_loader, optimizer, *, consistency_weight=12):
    """Train ``net`` for one epoch with a clean/augmented consistency loss.

    Each batch from ``train_loader`` is expected to be ``(images, targets)``
    where ``images`` is a sequence of three equally sized tensors (the logits
    are split three ways using ``images[0].size(0)``): the clean view followed
    by two augmented views. The views are concatenated along dim 0, moved to
    the GPU, and -- with probability ``args.cn_prob`` -- passed through
    ``cn_op_2ins_space_chan`` before the forward pass.

    The optimised loss is::

        cross_entropy(clean logits, targets) + consistency_weight * consistency

    where ``consistency`` is the mean KL divergence of each of the three
    softmax distributions from their (clamped) average.

    Args:
        net: Model to train; switched to training mode here.
        train_loader: Iterable of ``(images, targets)`` batches as described
            above. Must support ``len()`` (used in progress output).
        optimizer: Optimizer stepped once per batch.
        consistency_weight: Multiplier applied to the consistency term.
            Defaults to 12, the value this function has always used.

    Returns:
        ``top1.avg`` -- the epoch average of the first value returned by
        ``error(..., topk=(1, 5))`` as accumulated by ``AverageMeter``
        (presumably the top-1 error; confirm against those helpers).

    Note:
        Reads the module-level ``args`` (``cn_prob``, ``beta``, ``crop``,
        ``print_freq``) and requires CUDA.
    """
    print('running train_cn_image_augmix')
    net.train()

    losses = AverageMeter()
    s_losses = AverageMeter()
    c_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()  # tracked alongside top1; only top1 is returned

    for i, (images, targets) in enumerate(train_loader):
        # Every view has the same batch size; compute it once per iteration.
        batch_size = images[0].size(0)

        images_all = torch.cat(images, 0).cuda()
        targets = targets.cuda()

        # Apply the CrossNorm op to the whole stacked batch with prob. cn_prob.
        r = np.random.rand(1)
        if r < args.cn_prob:
            images_all = cn_op_2ins_space_chan(
                images_all, beta=args.beta, crop=args.crop)

        logits_all = net(images_all)
        logits_clean, logits_aug1, logits_aug2 = torch.split(
            logits_all, batch_size)

        # Cross-entropy is only computed on clean images.
        sup_loss = F.cross_entropy(logits_clean, targets)

        p_clean = F.softmax(logits_clean, dim=1)
        p_aug1 = F.softmax(logits_aug1, dim=1)
        p_aug2 = F.softmax(logits_aug2, dim=1)

        # Clamp mixture distribution to avoid exploding KL divergence.
        # F.kl_div expects its first argument in log-space.
        log_mixture = torch.clamp(
            (p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
        consist_loss = (
            F.kl_div(log_mixture, p_clean, reduction='batchmean') +
            F.kl_div(log_mixture, p_aug1, reduction='batchmean') +
            F.kl_div(log_mixture, p_aug2, reduction='batchmean')) / 3.

        total_loss = sup_loss + consistency_weight * consist_loss

        s_losses.update(sup_loss.item(), batch_size)
        c_losses.update(consist_loss.item(), batch_size)
        losses.update(total_loss.item(), batch_size)

        err1, err5 = error(logits_clean, targets, topk=(1, 5))  # pylint: disable=unbalanced-tuple-unpacking
        top1.update(err1.item(), batch_size)
        top5.update(err5.item(), batch_size)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            print('Iter: [{0}/{1}]\t'
                  'Supervised Loss {s_losses.val:.4f} ({s_losses.avg:.4f})\t'
                  'Consistency Loss {c_losses.val:.4f} ({c_losses.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      i, len(train_loader),
                      s_losses=s_losses, c_losses=c_losses, loss=losses))

    return top1.avg