def train()

in scripts/train_detection.py
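
Runs one training epoch for the detection model: each batch is moved to the GPU, the per-head losses are reduced to scalars and combined into a weighted total, an optimizer step is taken, the loss statistics are gathered across all distributed workers, and progress is logged at a fixed interval. The function relies on names available at module level in the same script (time, torch, OrderedDict, distributed, NETWORK_INPUTS, AverageMeter, all_reduce_losses and the logging helper) and returns the updated global step.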


def train(model, optimizer, scheduler, dataloader, meters, **varargs):
    model.train()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    optimizer.zero_grad()
    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]

    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)

    # Start the data-loading timer for the first batch
    data_time = time.time()
    for it, batch in enumerate(dataloader):
        # Move the inputs required by the network to the GPU
        batch = {k: batch[k].cuda(device=varargs["device"], non_blocking=True) for k in NETWORK_INPUTS}

        # Record the time spent loading and uploading the batch
        data_time_meter.update(torch.tensor(time.time() - data_time))

        # Advance the global step and, when scheduling per batch, step the learning-rate scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        # Start timing the forward/backward pass for this batch
        batch_time = time.time()

        # Run the network in loss-only mode
        losses, _ = model(**batch, do_loss=True, do_prediction=False)
        distributed.barrier()

        # Reduce each loss to a scalar and combine them into a weighted total;
        # loss_weights is expected to follow the same ordering as the losses returned by the model
        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        losses["loss"] = sum(w * l for w, l in zip(loss_weights, losses.values()))

        # Backward pass and parameter update
        optimizer.zero_grad()
        losses["loss"].backward()
        optimizer.step()

        # Gather stats from all workers
        losses = all_reduce_losses(losses)

        # Update meters
        with torch.no_grad():
            for loss_name, loss_value in losses.items():
                meters[loss_name].update(loss_value.cpu())
        batch_time_meter.update(torch.tensor(time.time() - batch_time))

        # Clean-up
        del batch, losses

        # Log
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step,
                varargs["epoch"] + 1, varargs["num_epochs"],
                it + 1, len(dataloader),
                OrderedDict([
                    ("lr", scheduler.get_lr()[0]),
                    ("loss", meters["loss"]),
                    ("obj_loss", meters["obj_loss"]),
                    ("bbx_loss", meters["bbx_loss"]),
                    ("roi_cls_loss", meters["roi_cls_loss"]),
                    ("roi_bbx_loss", meters["roi_bbx_loss"]),
                    ("data_time", data_time_meter),
                    ("batch_time", batch_time_meter)
                ])
            )

        # Restart the data-loading timer for the next iteration
        data_time = time.time()

    return global_step
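
For reference, a minimal sketch of how train() might be driven from an epoch-level loop. Only the positional arguments and the keyword names match what train() actually reads from varargs; the construction of the model, optimizer, scheduler, dataloader and summary writer is omitted, and every concrete value below (meter momentum, loss weights, number of epochs, log interval) is illustrative rather than taken from the repository.

# Assumed to be constructed elsewhere (not shown in this file):
#   model, optimizer, scheduler, train_dataloader, device, summary (a summary writer, or None to disable logging)
loss_names = ["loss", "obj_loss", "bbx_loss", "roi_cls_loss", "roi_bbx_loss"]
meters = {name: AverageMeter((), 0.9) for name in loss_names}  # AverageMeter(shape, momentum), as used above
loss_weights = [1.0, 1.0, 1.0, 1.0]                            # one weight per loss returned by the model, illustrative
num_epochs = 12                                                # illustrative

global_step = 0
for epoch in range(num_epochs):
    global_step = train(
        model, optimizer, scheduler, train_dataloader, meters,
        device=device,
        epoch=epoch,
        num_epochs=num_epochs,
        global_step=global_step,
        loss_weights=loss_weights,
        batch_update=False,   # True would step the scheduler once per batch inside train()
        summary=summary,
        log_interval=50,
    )
    # With batch_update=False the scheduler is assumed to be stepped once per epoch here
    scheduler.step(epoch)

The loss gathering relies on all_reduce_losses, which is defined elsewhere in the repository. A plausible sketch, assuming the standard torch.distributed pattern of summing each value over all workers and dividing by the world size (the real helper may differ):

from collections import OrderedDict

import torch.distributed as distributed

def all_reduce_losses(losses):
    # Average every loss tensor over all distributed workers
    reduced = OrderedDict()
    for name, value in losses.items():
        value = value.clone().detach()
        distributed.all_reduce(value, op=distributed.ReduceOp.SUM)
        reduced[name] = value / distributed.get_world_size()
    return reduced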