def train()

in main_fixmatch.py

One epoch of FixMatch training: the loop walks the unlabeled loader, cycles the
smaller labeled loader alongside it, builds pseudo-labels from the weakly
augmented unlabeled view, and applies a confidence-masked consistency loss to
the strongly augmented view.

# imports required by this excerpt; `utils` and `lr_schedule` are project-local
# modules from the same repository
import math
import time

import torch
import torch.nn.functional as F

import lr_schedule
import utils

def train(train_loader_x, train_loader_u, model, optimizer, epoch, args):
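    # the meters below come from the project-local utils module (a minimal
    # AverageMeter sketch follows after this function)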
    batch_time = utils.AverageMeter('Time', ':6.3f')
    data_time = utils.AverageMeter('Data', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    losses_x = utils.AverageMeter('Loss_x', ':.4e')
    losses_u = utils.AverageMeter('Loss_u', ':.4e')
    top1_x = utils.AverageMeter('Acc_x@1', ':6.2f')
    top5_x = utils.AverageMeter('Acc_x@5', ':6.2f')
    top1_u = utils.AverageMeter('Acc_u@1', ':6.2f')
    top5_u = utils.AverageMeter('Acc_u@5', ':6.2f')
    mask_info = utils.AverageMeter('Mask', ':6.3f')
    curr_lr = utils.InstantMeter('LR', '')
    progress = utils.ProgressMeter(
        len(train_loader_u),
        [curr_lr, batch_time, data_time, losses, losses_x, losses_u, top1_x, top5_x, top1_u, top5_u, mask_info],
        prefix="Epoch: [{}/{}]\t".format(epoch, args.epochs))

    # the labeled set is typically much smaller than the unlabeled set, so the
    # labeled loader is exhausted and restarted several times per unlabeled
    # epoch; epoch_x counts those labeled passes so the distributed labeled
    # sampler reshuffles on every restart
    epoch_x = epoch * math.ceil(len(train_loader_u) / len(train_loader_x))
    if args.distributed:
        print("set epoch={} for labeled sampler".format(epoch_x))
        train_loader_x.sampler.set_epoch(epoch_x)
        print("set epoch={} for unlabeled sampler".format(epoch))
        train_loader_u.sampler.set_epoch(epoch)

    train_iter_x = iter(train_loader_x)
    # switch to train mode
    model.train()
    if args.eman:
        print("setting the ema model to eval mode")
        if hasattr(model, 'module'):
            model.module.ema.eval()
        else:
            model.ema.eval()

    end = time.time()
    for i, (images_u, targets_u) in enumerate(train_loader_u):
        try:
            images_x, targets_x = next(train_iter_x)
        except StopIteration:
            epoch_x += 1
            print("reshuffle train_loader_x at epoch={}".format(epoch_x))
            if args.distributed:
                print("set epoch={} for labeled sampler".format(epoch_x))
                train_loader_x.sampler.set_epoch(epoch_x)
            train_iter_x = iter(train_loader_x)
            images_x, targets_x = next(train_iter_x)

        images_u_w, images_u_s = images_u
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images_x = images_x.cuda(args.gpu, non_blocking=True)
            images_u_w = images_u_w.cuda(args.gpu, non_blocking=True)
            images_u_s = images_u_s.cuda(args.gpu, non_blocking=True)
        targets_x = targets_x.cuda(args.gpu, non_blocking=True)
        targets_u = targets_u.cuda(args.gpu, non_blocking=True)

        # warmup learning rate (lr_schedule is project-local; a warmup sketch
        # follows after this function)
        if epoch < args.warmup_epoch:
            warmup_step = args.warmup_epoch * len(train_loader_u)
            curr_step = epoch * len(train_loader_u) + i + 1
            lr_schedule.warmup_learning_rate(optimizer, curr_step, warmup_step, args)
        curr_lr.update(optimizer.param_groups[0]['lr'])

        # model forward: a single call yields logits for the labeled batch and
        # for the weak and strong views of the unlabeled batch
        logits_x, logits_u_w, logits_u_s = model(images_x, images_u_w, images_u_s)
        # pseudo label from the weakly augmented view; the in-place detach_
        # keeps gradients from flowing through the pseudo-labeling branch
        # (a toy demo of this step follows after this function)
        pseudo_label = torch.softmax(logits_u_w.detach_(), dim=-1)
        max_probs, pseudo_targets_u = torch.max(pseudo_label, dim=-1)
        # keep only pseudo-labels whose confidence clears args.threshold
        mask = max_probs.ge(args.threshold).float()

        # compute losses
        loss_x = F.cross_entropy(logits_x, targets_x, reduction='mean')
        loss_u = (F.cross_entropy(logits_u_s, pseudo_targets_u, reduction='none') * mask).mean()
        loss = loss_x + args.lambda_u * loss_u

        # measure accuracy and record loss
        losses.update(loss.item())
        losses_x.update(loss_x.item(), images_x.size(0))
        losses_u.update(loss_u.item(), images_u_w.size(0))
        acc1_x, acc5_x = utils.accuracy(logits_x, targets_x, topk=(1, 5))
        top1_x.update(acc1_x[0], logits_x.size(0))
        top5_x.update(acc5_x[0], logits_x.size(0))
        acc1_u, acc5_u = utils.accuracy(logits_u_w, targets_u, topk=(1, 5))
        top1_u.update(acc1_u[0], logits_u_w.size(0))
        top5_u.update(acc5_u[0], logits_u_w.size(0))
        mask_info.update(mask.mean().item(), mask.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update the ema model (see the EMA-update sketch after this function)
        if args.eman:
            if hasattr(model, 'module'):
                model.module.momentum_update_ema()
            else:
                model.momentum_update_ema()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
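
The meters above come from the project-local utils module, which is not shown
in this excerpt. A minimal AverageMeter compatible with how the code calls it
(construct with a name and a format string, then update(value, n)) is the
classic running-average helper; the real utils.AverageMeter may carry extra
display logic:

class AverageMeter:
    # minimal sketch compatible with the calls above; the real
    # utils.AverageMeter may differ in its formatting details
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count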
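
lr_schedule.warmup_learning_rate is likewise project-local. A minimal sketch
consistent with the call site above, assuming args.lr holds the target
post-warmup rate and that the warmup is linear, would be:

def warmup_learning_rate(optimizer, curr_step, warmup_step, args):
    # linearly ramp the learning rate from 0 to args.lr over warmup_step steps
    # (sketch only; the real lr_schedule module may shape the ramp differently)
    lr = args.lr * curr_step / warmup_step
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr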
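
The pseudo-labeling step is easy to check in isolation. In this toy demo the
logits and the 0.95 threshold are invented for illustration (0.95 is the
confidence threshold the FixMatch paper uses):

import torch

logits_u_w = torch.tensor([[4.0, 0.0, 0.0],   # confident prediction
                           [1.0, 0.9, 0.8]])  # ambiguous prediction
probs = torch.softmax(logits_u_w, dim=-1)
max_probs, pseudo_targets = probs.max(dim=-1)
mask = max_probs.ge(0.95).float()
# max_probs is roughly [0.96, 0.37], so mask == [1., 0.]: only the confident
# sample contributes to the unsupervised loss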
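
model.momentum_update_ema() lives on the model wrapper, which is also outside
this excerpt. A hypothetical wrapper showing the standard
exponential-moving-average update is sketched below; note that EMAN-style
updates additionally copy buffers such as BatchNorm running statistics, which
this sketch omits:

import torch

class EMAWrapper(torch.nn.Module):
    # hypothetical wrapper; the repo's actual model class is not shown here
    def __init__(self, main, ema, momentum=0.999):
        super().__init__()
        self.main, self.ema, self.momentum = main, ema, momentum

    @torch.no_grad()
    def momentum_update_ema(self):
        # ema <- m * ema + (1 - m) * main, parameter by parameter
        for p_ema, p_main in zip(self.ema.parameters(), self.main.parameters()):
            p_ema.mul_(self.momentum).add_(p_main, alpha=1.0 - self.momentum)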
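
Finally, a hypothetical caller invokes train() once per unlabeled epoch;
args.start_epoch is an assumed attribute:

# hypothetical epoch loop; everything here mirrors the function signature
for epoch in range(args.start_epoch, args.epochs):
    train(train_loader_x, train_loader_u, model, optimizer, epoch, args)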