# train_vclr.py
import mmcv
import torch
import torch.distributed as dist

# set_bn_train, AverageMeter, DistributedShuffle, moment_update, and logger
# are project-local helpers; the exact import paths depend on the repository
# layout (e.g. a utils module).
def train_vclr(epoch, train_loader, model, model_ema, contrast, contrast_tsn,
               criterion, optimizer, scheduler, writer, args):
    """Train VCLR for one epoch and return the epoch-average loss."""
    model.train()
    set_bn_train(model_ema)  # keep BatchNorm layers of the EMA (key) encoder in train mode

    batch_time = AverageMeter()
    loss_meter = AverageMeter()

    timer = mmcv.Timer()
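    # Each batch provides seven tensors: a query clip (xq), three keys
    # (x1: the same frame under a different augmentation; x2/x3: two
    # different frames), a temporal-order label, and a TSN-style
    # query/key pair (tsn_q/tsn_k).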
    for idx, (xq, x1, x2, x3, order_label, tsn_q, tsn_k) in enumerate(train_loader):
        xq = xq.cuda(non_blocking=True)  # query clip
        x1 = x1.cuda(non_blocking=True)  # same frame, different augmentation
        x2 = x2.cuda(non_blocking=True)  # different frame 1
        x3 = x3.cuda(non_blocking=True)  # different frame 2
        order_label = order_label.cuda(non_blocking=True)
        tsn_q = tsn_q.cuda(non_blocking=True)
        tsn_k = tsn_k.cuda(non_blocking=True)
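        # Note: non_blocking=True only overlaps the host-to-device copy with
        # compute when the DataLoader uses pinned host memory (pin_memory=True),
        # which is presumably set where the loader is built.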
        # forward keys through the EMA encoder (no gradients)
        with torch.no_grad():
            x1_shuffled, x1_backward_inds = DistributedShuffle.forward_shuffle(x1)
            x2_shuffled, x2_backward_inds = DistributedShuffle.forward_shuffle(x2)
            x3_shuffled, x3_backward_inds = DistributedShuffle.forward_shuffle(x3)
            x1_feat_inter, x1_feat_intra = model_ema(x1_shuffled)
            x2_feat_inter, x2_feat_intra = model_ema(x2_shuffled)
            x3_feat_inter, x3_feat_intra = model_ema(x3_shuffled)
            x1_feat_inter_all, x1_feat_inter = DistributedShuffle.backward_shuffle(x1_feat_inter, x1_backward_inds)
            x1_feat_intra_all, x1_feat_intra = DistributedShuffle.backward_shuffle(x1_feat_intra, x1_backward_inds)
            x2_feat_inter_all, x2_feat_inter = DistributedShuffle.backward_shuffle(x2_feat_inter, x2_backward_inds)
            x2_feat_intra_all, x2_feat_intra = DistributedShuffle.backward_shuffle(x2_feat_intra, x2_backward_inds)
            x3_feat_inter_all, x3_feat_inter = DistributedShuffle.backward_shuffle(x3_feat_inter, x3_backward_inds)
            x3_feat_intra_all, x3_feat_intra = DistributedShuffle.backward_shuffle(x3_feat_intra, x3_backward_inds)
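            # Shuffling keys across GPUs before the EMA forward pass and
            # unshuffling the features afterwards follows MoCo's ShuffleBN
            # trick: it stops per-device BatchNorm statistics from leaking
            # which samples share a GPU, which would let the encoder cheat on
            # the contrastive task. backward_shuffle appears to return both
            # the gathered all-rank features (used as the negative pool) and
            # this rank's own features.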
            # TSN and order (o3n) keys
            tsn_k_shuffle, tsn_k_backward_inds = DistributedShuffle.forward_shuffle(tsn_k)
            tsn_k_feat, o3n_k = model_ema(tsn_k_shuffle, tsn_mode=True)
            tsn_k_feat_all, tsn_k_feat = DistributedShuffle.backward_shuffle(tsn_k_feat, tsn_k_backward_inds)
            o3n_k_feat_all, o3n_k_feat = DistributedShuffle.backward_shuffle(o3n_k, tsn_k_backward_inds)
        # forward query through the online encoder
        xq_feat_inter, xq_feat_intra = model(xq)
        tsn_q_feat, o3n_q_feat, xq_logit_order = model(tsn_q, order_feat=o3n_k_feat, tsn_mode=True)
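        # xq_logit_order holds the logits of the temporal-order prediction
        # head, which conditions on the key-side order features (o3n_k_feat)
        # and is supervised by order_label below.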
        out_inter = contrast(xq_feat_inter,
                             x1_feat_inter, x2_feat_inter, x3_feat_inter,
                             torch.cat([x1_feat_inter_all, x2_feat_inter_all, x3_feat_inter_all], dim=0),
                             inter=True)
        out_intra = contrast(xq_feat_intra,
                             x1_feat_intra, x2_feat_intra, x3_feat_intra,
                             None, inter=False)
        out_tsn = contrast_tsn(tsn_q_feat, tsn_k_feat, tsn_k_feat_all)
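        # Three contrastive heads: inter-frame (negatives drawn from the
        # concatenated all-rank key pool), intra-frame (keys from the same
        # video only, hence no external negative pool), and a TSN-level head
        # over segment features. Each presumably returns NCE-style logits
        # that criterion turns into a scalar loss.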
        # losses
        loss_inter = criterion(out_inter)
        loss_intra = criterion(out_intra)
        loss_order = torch.nn.functional.cross_entropy(xq_logit_order, order_label)
        loss_tsn = criterion(out_tsn)
        loss = loss_inter + loss_intra + loss_order + loss_tsn  # unweighted sum of the four objectives
        # backward and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        moment_update(model, model_ema, args.alpha)
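        # moment_update refreshes the EMA (key) encoder from the online
        # encoder with momentum args.alpha; see the reference sketch at the
        # bottom of this file.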
        # update meters
        loss_meter.update(loss.item())
        batch_time.update(timer.since_last_check())
        # periodic console logging
        if idx % args.print_freq == 0:
            logger.info(
                'Train: [{:>3d}]/[{:>4d}/{:>4d}] BT={:>0.3f}/{:>0.3f} '
                'Loss={:>0.3f} {:>0.3f} {:>0.3f} {:>0.3f} {:>0.3f}/{:>0.3f}'.format(
                    epoch, idx, len(train_loader),
                    batch_time.val, batch_time.avg,
                    loss.item(), loss_inter.item(), loss_intra.item(),
                    loss_order.item(), loss_tsn.item(), loss_meter.avg))
        # tensorboard summaries (rank 0 only); n_iter assumes epoch starts at 1
        if dist.get_rank() == 0:
            n_iter = idx + len(train_loader) * (epoch - 1)
            writer.add_scalar('Loss/loss', loss.item(), n_iter)
            writer.add_scalar('Loss/loss_avg', loss_meter.avg, n_iter)
            writer.add_scalar('Loss/loss_inter', loss_inter.item(), n_iter)
            writer.add_scalar('Loss/loss_intra', loss_intra.item(), n_iter)
            writer.add_scalar('Loss/loss_order', loss_order.item(), n_iter)
            writer.add_scalar('Loss/loss_tsn', loss_tsn.item(), n_iter)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], n_iter)
    return loss_meter.avg
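
# For reference, minimal sketches of two of the project-local helpers used
# above. These are illustrative assumptions, not the repository's actual
# implementations, and are named with a leading underscore so they do not
# shadow the real imports.

def _moment_update_sketch(model, model_ema, alpha):
    """EMA update: model_ema = alpha * model_ema + (1 - alpha) * model."""
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.data.mul_(alpha).add_(p.data, alpha=1.0 - alpha)


class _AverageMeterSketch:
    """Tracks the latest value (val) and the running average (avg)."""

    def __init__(self):
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count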