def train()

in main.py [0:0]

# Assumed module-level context for this excerpt (the imports and helpers
# live elsewhere in the repository):
import os
import time

import torch
import torch.distributed as dist

# apex is only required when args.use_fp16 is set.
# args, logger, group, sk_schedule, sk_counter, AverageMeter, cluster,
# get_loss and trigger_job_requeue are defined at module level in main.py
# and its utilities.

def train(train_loader, model, optimizer, epoch, writer, selflabels):

    # sk_schedule holds the iterations at which to re-run clustering (next
    # one at the end); sk_counter tracks completed clustering rounds.
    global sk_schedule
    global sk_counter


    # Put model in train mode
    model.train()

    # Initialize logging meters (a minimal AverageMeter is sketched after
    # this function)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    world_size = args.world_size
    dataset_bs = train_loader.batch_size

    end = time.time()
    batches_thusfar = epoch * len(train_loader)
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # ============ Get inputs ... ============
        video, audio, _, selected, _ = inputs
        video, audio = video.cuda(), audio.cuda()

        # ============ Occasional clustering via Sinkhorn-Knopp ... ===========
        # The next scheduled clustering iteration sits at the end of
        # sk_schedule; once reached, consume it and recompute the self-labels
        # (a Sinkhorn-Knopp sketch follows this function).
        if batches_thusfar + it >= sk_schedule[-1]:
            # optimize labels
            with torch.no_grad():
                _ = sk_schedule.pop()
                selflabels = cluster(
                    args, selflabels, train_loader.dataset, model, sk_counter,
                    logger, writer, group,
                    (batches_thusfar + it) * dataset_bs * world_size
                )

        # ============ forward passes ... ============
        feat_v, feat_a = model(video, audio)

        # ============ SeLaVi loss ... ============
        # Select this batch's self-labels (one column per clustering head);
        # get_loss is sketched after this function.
        if args.headcount == 1:
            labels = selflabels[selected, 0]
        else:
            labels = selflabels[selected, :]
        loss_vid = get_loss(feat_v, labels, headcount=args.headcount)
        loss_aud = get_loss(feat_a, labels, headcount=args.headcount)
        loss = 0.5 * loss_vid + 0.5 * loss_aud

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        if args.use_fp16:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        iteration = epoch * len(train_loader) + it
        if args.rank == 0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    lr=optimizer.param_groups[0]["lr"],
                )
            )

            # Log to tensorboard
            if writer:
                writer.add_scalar(
                    'loss/iter', loss.item(), iteration)
                writer.add_scalar(
                    'lr/iter', optimizer.param_groups[0]["lr"], iteration)
                writer.add_scalar(
                    'batch_time/iter', batch_time.avg, iteration)
                writer.add_scalar(
                    'data_time/iter', data_time.avg, iteration)

        # ============ signal handling ... ============
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            if args.rank == 0:
                logger.info("Beginning reqeue")
                trigger_job_requeue(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"))

    dist.barrier()
    torch.cuda.empty_cache()
    return (epoch, losses.avg), selflabels
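

The loop's timing and loss meters rely on an AverageMeter helper that is not
shown in this excerpt. A minimal version consistent with the
.val / .avg / .update(value, n) usage above, following the common
PyTorch-examples pattern:

class AverageMeter:
    """Track the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.avg = 0.0   # running average
        self.sum = 0.0   # weighted sum of values
        self.count = 0   # total weight (e.g. number of samples)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count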
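
The cluster() call above re-optimizes the self-labels with the
Sinkhorn-Knopp algorithm. The real helper also gathers features across
processes, handles multiple heads, and logs progress; the core iteration
looks roughly like the sketch below (the temperature and iteration count
are illustrative assumptions, not values from this repository):

import torch

def sinkhorn_knopp_labels(scores, n_iters=100, temperature=0.1):
    """Balanced cluster assignment from a (num_samples, num_clusters)
    score matrix: alternately normalize columns and rows so every cluster
    receives an equal share of samples, then take hard labels."""
    Q = torch.exp(scores / temperature)
    Q /= Q.sum()
    n_samples, n_clusters = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(dim=0, keepdim=True)  # equal total mass per cluster
        Q /= n_clusters
        Q /= Q.sum(dim=1, keepdim=True)  # unit mass per sample
        Q /= n_samples
    return Q.argmax(dim=1)  # hard self-labels, shape (num_samples,)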
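
get_loss computes a cross-entropy between the head outputs and the current
self-labels. A plausible sketch, assuming the model returns a single logits
tensor when headcount == 1 and a list of per-head logits otherwise
(matching how labels is indexed in the loop):

import torch.nn.functional as F

def get_loss(feats, labels, headcount=1):
    """Cross-entropy between predicted cluster logits and self-labels."""
    if headcount == 1:
        # feats: (batch, num_clusters); labels: (batch,)
        return F.cross_entropy(feats, labels)
    # feats: list of per-head logits; labels: (batch, headcount)
    losses = [F.cross_entropy(f, labels[:, h]) for h, f in enumerate(feats)]
    return sum(losses) / len(losses)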