def train()

in main_stica.py


def train(train_loader, model, optimizer, epoch, lr_schedule, writer):
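    """Run one epoch of training.

    Per iteration: the learning rate is taken from lr_schedule, each video clip
    is forwarded through the model, NCE losses are computed over the clips' crop
    features (plus an optional cross-modal audio-visual loss weighted by
    args.cross_modal_alpha), and the combined loss is backpropagated, through
    apex AMP when args.use_fp16. Progress is logged to the console and, if a
    writer is given, to TensorBoard. Returns the epoch index and the average loss.
    """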
    # Put model in train mode
    model.train()
    XE = torch.nn.CrossEntropyLoss()
    # Init Logger meters
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    croplosses_meter = AverageMeter()
    avlosses = AverageMeter()

    end = time.time()
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate
        iteration = epoch * len(train_loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]

        # get inputs
        video, audio, _, _, _ = inputs
        audio = audio.cuda(non_blocking=True)

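        # Collect per-clip video features, per-crop video features, and
        # audio features for the NCE losses below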
        feat_v_nce_lst = []
        crop_feat_v_nces_lst = []
        feat_a_nce_lst = []

        # FORWARD PASSES
        for i in range(len(video)):
            # get the i-th video clip (single-element list slice, so cat
            # simply returns that tensor) and move it to the GPU
            video_input = torch.cat(video[i: i+1]).cuda(non_blocking=True)

            # get crop params
            params = fmcrop_params(
                duration=model.module.duration,
                s_large_crops=args.num_large_crops,  
                s_small_crops=args.num_small_crops,
                t_large_crops=args.num_large_tcrops, 
                t_small_crops=args.num_small_tcrops
            )

            # Forward pass
            feat_v_nce, crop_feat_v_nces, feat_a_nce = model(
                video_input, audio, params=params)
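            # Returns the clip-level video NCE feature, the per-crop video
            # NCE features, and the audio NCE feature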

            # Save features
            feat_v_nce_lst.append(feat_v_nce)
            crop_feat_v_nces_lst.append(crop_feat_v_nces)
            feat_a_nce_lst.append(feat_a_nce)

        # CROP & LOSS COMPUTATION
        crop_losses, counters = nce_crop_losses_dual(
            feats_v=crop_feat_v_nces_lst[0],
            feats_v2=crop_feat_v_nces_lst[1], 
            XE=XE,
            s_large_crops=args.num_large_crops,  
            s_small_crops=args.num_small_crops,
            t_large_crops=args.num_large_tcrops, 
            t_small_crops=args.num_small_tcrops, 
            temp=args.temp
        )
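        # Average the summed crop losses by the counters returned alongside
        # them (the number of contributing loss terms)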
        loss_crops = sum(crop_losses) / sum(counters)
        if args.cross_modal_alpha > 0:
            loss_av = 0.5 * (
                gdt_loss(feat_v_nce_lst[0], feat_a_nce_lst[0], XE) +
                gdt_loss(feat_v_nce_lst[1], feat_a_nce_lst[1], XE)
            )
        else:
            loss_av = torch.tensor(0)

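        # Blend the within-modal crop loss with the cross-modal AV loss
        # using the weight args.cross_modal_alpha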
        if args.cross_modal_alpha > 0:
            loss = (
                (1. - args.cross_modal_alpha) * loss_crops + 
                args.cross_modal_alpha * loss_av
            )
        else:
            loss = (1. - args.cross_modal_alpha) * loss_crops

        # BACKWARD AND OPTIMIZER STEP
        optimizer.zero_grad()
        if args.use_fp16:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # LOGGING
        bs = audio.size(0)
        losses.update(loss.item(), bs)
        avlosses.update(loss_av.item(), bs)
        croplosses_meter.update(loss_crops.item(), bs)
        batch_time.update(time.time() - end)
        end = time.time()
        if args.rank == 0 and it % 50 == 0:
            logger.info(
                "Epoch: [{0}][{1}]\t"
                "Time {batch_time.val:.2f} ({batch_time.avg:.2f})\t"
                "Data {data_time.val:.2f} ({data_time.avg:.2f})\t"
                "Loss {loss.val:.2f} ({loss.avg:.2f})\t"
                "AVLoss {avloss.val:.2f} ({avloss.avg:.2f})\t"
                "CropLoss {closs.val:.2f} ({closs.avg:.2f})\t"
                "Lr: {lr:.4f}".format(
                    epoch,
                    it,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    avloss=avlosses,
                    closs=croplosses_meter,
                    lr=optimizer.param_groups[0]["lr"],
                )
            )

            # Log to TensorBoard
            if writer:
                log_scalars(writer, loss, loss_crops, crop_losses, loss_av,
                            counters, optimizer, batch_time, data_time, iteration)

        # ============ signal handling ... ============
        if os.environ['SIGNAL_RECEIVED'] == 'True':
            if args.rank == 0:
                logger.info("Beginning reqeue")
                trigger_job_requeue(os.path.join(
                    args.dump_path, 
                    "checkpoint.pth.tar"
                ))

    return epoch, losses.avg
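
For reference, a minimal sketch of how train() might be driven from the main script. It assumes a simple linear learning-rate schedule and the standard torch.utils.tensorboard SummaryWriter; start_epoch and the flags base_lr, final_lr, and epochs are illustrative names, not taken from main_stica.py (dump_path and rank do appear in the code above).

import numpy as np
from torch.utils.tensorboard import SummaryWriter

# Hypothetical driver loop; only rank 0 writes TensorBoard logs.
writer = SummaryWriter(log_dir=args.dump_path) if args.rank == 0 else None
lr_schedule = np.linspace(args.base_lr, args.final_lr,
                          args.epochs * len(train_loader))
for epoch in range(start_epoch, args.epochs):
    epoch, avg_loss = train(train_loader, model, optimizer,
                            epoch, lr_schedule, writer)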