# scripts/train_detection.py
def train(model, optimizer, scheduler, dataloader, meters, **varargs):
    """Run one training epoch and return the updated global step counter.

    Parameters
    ----------
    model : nn.Module-like
        Called as ``model(**batch, do_loss=True, do_prediction=False)`` and
        expected to return ``(losses_dict, predictions)``.
    optimizer : torch optimizer
        Stepped once per batch.
    scheduler : LR scheduler
        Stepped per batch (with the global step) only when
        ``varargs["batch_update"]`` is truthy; otherwise it is presumably
        stepped per epoch by the caller — confirm against the call site.
    dataloader : iterable of batches
        Its ``batch_sampler`` must expose ``set_epoch`` (distributed-style
        sampler) so shuffling differs per epoch.
    meters : dict
        Name -> AverageMeter-like objects; must contain at least the keys
        logged below ("loss", "obj_loss", "bbx_loss", "roi_cls_loss",
        "roi_bbx_loss").
    **varargs
        Expects keys: "epoch", "num_epochs", "global_step", "loss_weights",
        "device", "batch_update", "summary", "log_interval".

    Returns
    -------
    int
        The global step counter after processing this epoch.
    """
    model.train()
    # Re-seed the distributed sampler so each epoch sees a different shuffle.
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    # NOTE(review): gradients are also zeroed at the top of every iteration
    # below, so this pre-loop zero_grad is effectively redundant (harmless).
    optimizer.zero_grad()
    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]
    # Timing meters reuse the loss meter's momentum for consistent smoothing.
    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)
    data_time = time.time()
    for it, batch in enumerate(dataloader):
        # Upload batch: only the keys listed in NETWORK_INPUTS (a project-level
        # constant) are moved to the target device.
        batch = {k: batch[k].cuda(device=varargs["device"], non_blocking=True) for k in NETWORK_INPUTS}
        data_time_meter.update(torch.tensor(time.time() - data_time))
        # Update scheduler: per-batch LR schedule when batch_update is set.
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)
        batch_time = time.time()
        # Run network (loss computation only, no predictions).
        losses, _ = model(**batch, do_loss=True, do_prediction=False)
        # Synchronize workers before reducing losses; presumably this also
        # keeps the batch_time measurement comparable across ranks — confirm.
        distributed.barrier()
        # Reduce each loss tensor to a scalar (mean over any per-sample dims).
        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        # NOTE(review): loss_weights must be ordered to match the insertion
        # order of the losses OrderedDict returned by the model — there is no
        # key-based pairing here, only positional zip.
        losses["loss"] = sum(w * l for w, l in zip(loss_weights, losses.values()))
        optimizer.zero_grad()
        losses["loss"].backward()
        optimizer.step()
        # Gather stats from all workers (all_reduce_losses is a project helper;
        # presumably it averages each loss across ranks — confirm).
        losses = all_reduce_losses(losses)
        # Update meters (no_grad: meter updates must not track gradients).
        with torch.no_grad():
            for loss_name, loss_value in losses.items():
                meters[loss_name].update(loss_value.cpu())
        batch_time_meter.update(torch.tensor(time.time() - batch_time))
        # Clean-up: drop references before the next batch to reduce peak memory.
        del batch, losses
        # Log every log_interval iterations when a summary writer is attached.
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step,
                varargs["epoch"] + 1, varargs["num_epochs"],
                it + 1, len(dataloader),
                OrderedDict([
                    ("lr", scheduler.get_lr()[0]),
                    ("loss", meters["loss"]),
                    ("obj_loss", meters["obj_loss"]),
                    ("bbx_loss", meters["bbx_loss"]),
                    ("roi_cls_loss", meters["roi_cls_loss"]),
                    ("roi_bbx_loss", meters["roi_bbx_loss"]),
                    ("data_time", data_time_meter),
                    ("batch_time", batch_time_meter)
                ])
            )
        # Restart the data-loading timer for the next iteration.
        data_time = time.time()
    return global_step