def train()

in scripts/train_imagenet.py [0:0]
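
Runs one training epoch: steps the learning-rate scheduler (per epoch or per step, depending on the config), performs the forward/backward/optimizer updates with optional gradient clipping, aggregates loss and top-1/top-5 accuracy across distributed ranks, and logs progress to the console and TensorBoard.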


def train(train_loader, model, criterion, optimizer, scheduler, epoch):
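    # logger, conf and tb (a TensorBoard-style summary writer) are module-level
    # objects set up by the surrounding script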
    global logger, conf, tb
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

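    # epoch-mode schedules step the learning rate once per epoch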
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(epoch)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
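        # step-mode schedules step the learning rate every iteration,
        # keyed by the global step index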
        if conf["optimizer"]["schedule"]["mode"] == "step":
            scheduler.step(i + epoch * len(train_loader))

        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if conf["optimizer"]["clip"] != 0.0:
            nn.utils.clip_grad_norm_(model.parameters(), conf["optimizer"]["clip"])
        optimizer.step()

        # measure accuracy and record loss
        with torch.no_grad():
            output = output.detach()
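            # convert the mean loss into a per-batch sum so the all-reduce and
            # divide-by-count below yield a correctly weighted global average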
            loss = loss.detach() * target.shape[0]
            prec1, prec5 = utils.accuracy_sum(output, target, topk=(1, 5))
            count = target.new_tensor([target.shape[0]], dtype=torch.long)
            if dist.is_initialized():
                dist.all_reduce(count, dist.ReduceOp.SUM)
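            # sum each metric over all ranks, then normalize by the total
            # number of samples seen this step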
            for meter, val in (losses, loss), (top1, prec1), (top5, prec5):
                if dist.is_initialized():
                    dist.all_reduce(val, dist.ReduceOp.SUM)
                val /= count.item()
                meter.update(val.item(), count.item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

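        # print running averages every print_freq iterations
        # (args is read from the enclosing script)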
        if i % args.print_freq == 0:
            logger.info(
                "Epoch: [{0}][{1}/{2}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f}) \t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f}) \t"
                "Loss {loss.val:.4f} ({loss.avg:.4f}) \t"
                "Prec@1 {top1.val:.3f} ({top1.avg:.3f}) \t"
                "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
                    epoch,
                    i,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    top1=top1,
                    top5=top5,
                )
            )

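        # only rank 0 (or the single non-distributed process) writes TensorBoard summaries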
        if not dist.is_initialized() or dist.get_rank() == 0:
            tb.add_scalar("train/loss", losses.val, i + epoch * len(train_loader))
            tb.add_scalar(
                "train/lr", scheduler.get_lr()[0], i + epoch * len(train_loader)
            )
            tb.add_scalar("train/top1", top1.val, i + epoch * len(train_loader))
            tb.add_scalar("train/top5", top5.val, i + epoch * len(train_loader))
            if args.log_hist and i % 10 == 0:
                for name, param in model.named_parameters():
                    if "fc" in name or "bn_out" in name:
                        tb.add_histogram(
                            name,
                            param.detach().cpu().numpy(),
                            i + epoch * len(train_loader),
                        )
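
Below is a minimal, hypothetical driver for this function, using a toy in-memory loader and stock torchvision/optim components as stand-ins for whatever the real script builds from its config. It assumes the module-level logger, conf, tb, args and utils referenced above are already set up (including conf["optimizer"]["schedule"]["mode"] and conf["optimizer"]["clip"]) and that a CUDA device is available; every name introduced here (toy_loader, the ResNet-18 model, the StepLR schedule, the epoch count) is illustrative and not part of the original script.

import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, TensorDataset

# Random tensors standing in for ImageNet batches (shapes and sizes are illustrative).
toy_loader = DataLoader(
    TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 1000, (32,))),
    batch_size=8,
)

# DataParallel moves the CPU input batches onto the GPU, since train() above
# only moves the targets.
model = nn.DataParallel(models.resnet18(num_classes=1000).cuda())
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

for epoch in range(90):  # placeholder epoch count
    train(toy_loader, model, criterion, optimizer, scheduler, epoch)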