# Chunk extracted from train_vclr.py: train_vclr()

def train_vclr(epoch, train_loader, model, model_ema, contrast, contrast_tsn, criterion, optimizer, scheduler, writer, args):
    """Run one epoch of VCLR self-supervised training and return the average loss.

    Each batch contributes four loss terms that are summed with equal weight:
      * ``loss_inter`` - contrastive loss between the query clip and key clips
        from other frames (inter-frame), via ``contrast(..., inter=True)``.
      * ``loss_intra`` - contrastive loss against the same-frame different-aug
        key, via ``contrast(..., inter=False)``.
      * ``loss_order`` - cross-entropy on the predicted clip-order logits
        against ``order_label``.
      * ``loss_tsn`` - contrastive loss on segment-level (TSN) features.

    Key features are produced by the momentum encoder ``model_ema`` under
    ``torch.no_grad()``. Inputs to the EMA model are shuffled across
    distributed ranks and the resulting features un-shuffled afterwards
    (``DistributedShuffle.forward_shuffle`` / ``backward_shuffle``) —
    presumably to decorrelate batch-norm statistics across ranks (shuffle-BN
    style); confirm against ``DistributedShuffle``'s implementation.

    Args:
        epoch: Epoch index; the tensorboard step uses ``epoch - 1``, so the
            caller appears to pass a 1-based epoch — TODO confirm.
        train_loader: Iterable yielding tuples
            ``(xq, x1, x2, x3, order_label, tsn_q, tsn_k)``; tensor shapes
            are not visible from this chunk.
        model: Query encoder being optimized; ``model(x)`` returns
            ``(feat_inter, feat_intra)``, and with ``tsn_mode=True`` plus
            ``order_feat`` returns ``(tsn_feat, o3n_feat, order_logits)``.
        model_ema: Momentum (key) encoder; no gradients flow through it, and
            it is updated via ``moment_update`` with factor ``args.alpha``.
        contrast: Contrast head for inter/intra clip-level features.
        contrast_tsn: Contrast head for segment-level (TSN) features.
        criterion: Loss applied to each contrast head's output (e.g. NCE).
        optimizer: Stepped once per batch after ``loss.backward()``.
        scheduler: LR scheduler stepped once per batch (per-iteration
            schedule, not per-epoch).
        writer: Tensorboard ``SummaryWriter``; written to on rank 0 only.
        args: Namespace providing at least ``alpha`` (EMA momentum) and
            ``print_freq`` (logging interval in iterations).

    Returns:
        float: Average of the total loss over all iterations of this epoch.
    """
    model.train()
    # Put the EMA model's BN layers into train mode so its batch statistics
    # track the key batches (the rest of model_ema stays in its current mode).
    set_bn_train(model_ema)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    timer = mmcv.Timer()
    for idx, (xq, x1, x2, x3, order_label, tsn_q, tsn_k) in enumerate(train_loader):
        xq = xq.cuda(non_blocking=True)  # query
        x1 = x1.cuda(non_blocking=True)  # same frame diff aug
        x2 = x2.cuda(non_blocking=True)  # diff frame 1
        x3 = x3.cuda(non_blocking=True)  # diff frame 2
        order_label = order_label.cuda(non_blocking=True)
        tsn_q = tsn_q.cuda(non_blocking=True)
        tsn_k = tsn_k.cuda(non_blocking=True)
        # forward keys: EMA encoder, no gradients; shuffle inputs across
        # ranks before the forward pass and un-shuffle the features after.
        # backward_shuffle returns (all-rank gathered features, this rank's
        # restored features) — the gathered tensors serve as negatives below.
        with torch.no_grad():
            x1_shuffled, x1_backward_inds = DistributedShuffle.forward_shuffle(x1)
            x2_shuffled, x2_backward_inds = DistributedShuffle.forward_shuffle(x2)
            x3_shuffled, x3_backward_inds = DistributedShuffle.forward_shuffle(x3)
            x1_feat_inter, x1_feat_intra = model_ema(x1_shuffled)
            x2_feat_inter, x2_feat_intra = model_ema(x2_shuffled)
            x3_feat_inter, x3_feat_intra = model_ema(x3_shuffled)
            x1_feat_inter_all, x1_feat_inter = DistributedShuffle.backward_shuffle(x1_feat_inter, x1_backward_inds)
            x1_feat_intra_all, x1_feat_intra = DistributedShuffle.backward_shuffle(x1_feat_intra, x1_backward_inds)
            x2_feat_inter_all, x2_feat_inter = DistributedShuffle.backward_shuffle(x2_feat_inter, x2_backward_inds)
            x2_feat_intra_all, x2_feat_intra = DistributedShuffle.backward_shuffle(x2_feat_intra, x2_backward_inds)
            x3_feat_inter_all, x3_feat_inter = DistributedShuffle.backward_shuffle(x3_feat_inter, x3_backward_inds)
            x3_feat_intra_all, x3_feat_intra = DistributedShuffle.backward_shuffle(x3_feat_intra, x3_backward_inds)

            # tsn, o3n: same shuffle/un-shuffle treatment for the segment-level
            # key; o3n_k_feat is later fed to the query model as order_feat.
            tsn_k_shuffle, tsn_k_backward_inds = DistributedShuffle.forward_shuffle(tsn_k)
            tsn_k_feat, o3n_k = model_ema(tsn_k_shuffle, tsn_mode=True)
            tsn_k_feat_all, tsn_k_feat = DistributedShuffle.backward_shuffle(tsn_k_feat, tsn_k_backward_inds)
            o3n_k_feat_all, o3n_k_feat = DistributedShuffle.backward_shuffle(o3n_k, tsn_k_backward_inds)

        # forward query (gradients flow through `model` only)
        xq_feat_inter, xq_feat_intra = model(xq)
        tsn_q_feat, o3n_q_feat, xq_logit_order = model(tsn_q, order_feat=o3n_k_feat, tsn_mode=True)

        # inter-frame contrast uses the all-rank gathered key features
        # (concatenated over the three key frames) as the negative pool;
        # intra-frame contrast passes None there.
        out_inter = contrast(xq_feat_inter,
                             x1_feat_inter, x2_feat_inter, x3_feat_inter,
                             torch.cat([x1_feat_inter_all, x2_feat_inter_all, x3_feat_inter_all], dim=0), inter=True)
        out_intra = contrast(xq_feat_intra,
                             x1_feat_intra, x2_feat_intra, x3_feat_intra, None, inter=False)
        out_tsn = contrast_tsn(tsn_q_feat,
                               tsn_k_feat, tsn_k_feat_all)

        # loss calc: four terms summed with equal (unit) weights
        loss_inter = criterion(out_inter)
        loss_intra = criterion(out_intra)
        loss_order = torch.nn.functional.cross_entropy(xq_logit_order, order_label)
        loss_tsn = criterion(out_tsn)
        loss = loss_inter + loss_intra + loss_order + loss_tsn
        # backward
        optimizer.zero_grad()
        loss.backward()
        # update params; scheduler is advanced every iteration, and the EMA
        # encoder is updated from the freshly stepped query encoder.
        optimizer.step()
        scheduler.step()
        moment_update(model, model_ema, args.alpha)
        # update meters
        loss_meter.update(loss.item())
        batch_time.update(timer.since_last_check())
        # print info (note: logs on every rank, unlike the tensorboard writes)
        if idx % args.print_freq == 0:
            logger.info(
                'Train: [{:>3d}]/[{:>4d}/{:>4d}] BT={:>0.3f}/{:>0.3f} Loss={:>0.3f} {:>0.3f} {:>0.3f} {:>0.3f} {:>0.3f}/{:>0.3f}'.format(
                    epoch, idx, len(train_loader),
                    batch_time.val, batch_time.avg,
                    loss.item(), loss_inter.item(), loss_intra.item(), loss_order.item(), loss_tsn.item(), loss_meter.avg,
                ))

        # summary to tensorboard (rank 0 only); global step assumes a
        # 1-based epoch — TODO confirm with the caller
        if dist.get_rank() == 0:
            n_iter = idx + len(train_loader) * (epoch - 1)
            writer.add_scalar('Loss/loss', loss.item(), n_iter)
            writer.add_scalar('Loss/loss_avg', loss_meter.avg, n_iter)
            writer.add_scalar('Loss/loss_inter', loss_inter.item(), n_iter)
            writer.add_scalar('Loss/loss_intra', loss_intra.item(), n_iter)
            writer.add_scalar('Loss/loss_order', loss_order.item(), n_iter)
            writer.add_scalar('Loss/loss_tsn', loss_tsn.item(), n_iter)

            # first param group's lr is representative for logging
            currlr = 0.0
            for param_group in optimizer.param_groups:
                currlr = param_group['lr']
                break
            writer.add_scalar('lr', currlr, n_iter)

    return loss_meter.avg