tools/train_net_shapenet.py
def training_loop(cfg, cp, model, optimizer, scheduler, loaders, device, loss_fn):
    Timer.timing = False
    iteration_timer = Timer("Iteration")

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    if hasattr(model, "module"):
        params = list(model.module.parameters())
    else:
        params = list(model.parameters())
    loss_moving_average = cp.data.get("loss_moving_average", None)
    while cp.epoch < cfg.SOLVER.NUM_EPOCHS:
        if comm.is_main_process():
            logger.info("Starting epoch %d / %d" % (cp.epoch + 1, cfg.SOLVER.NUM_EPOCHS))

        # When using a DistributedSampler we need to manually set the epoch so that
        # the data is shuffled differently at each epoch
        for loader in loaders.values():
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(cp.epoch)
        for i, batch in enumerate(loaders["train"]):
            if i == 0:
                iteration_timer.start()
            else:
                iteration_timer.tick()
            batch = loaders["train"].postprocess(batch, device)
            imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch
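            # Sanity check: abort training if any model parameter has become
            # non-finite (NaN/Inf) before the forward pass.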
            num_infinite_params = 0
            for p in params:
                num_infinite_params += (torch.isfinite(p.data) == 0).sum().item()
            if num_infinite_params > 0:
                msg = "ERROR: Model has %d non-finite params (before forward!)"
                logger.info(msg % num_infinite_params)
                return
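            # During the first VOXEL_ONLY_ITERS iterations the model runs in
            # voxel-only mode; the non-voxel losses are zeroed out below so that
            # only the voxel head is trained.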
            model_kwargs = {}
            if cfg.MODEL.VOXEL_ON and cp.t < cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS:
                model_kwargs["voxel_only"] = True
            with Timer("Forward"):
                voxel_scores, meshes_pred = model(imgs, **model_kwargs)
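            # Abort training if the predicted meshes contain non-finite vertices.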
            num_infinite = 0
            for cur_meshes in meshes_pred:
                cur_verts = cur_meshes.verts_packed()
                num_infinite += (torch.isfinite(cur_verts) == 0).sum().item()
            if num_infinite > 0:
                logger.info("ERROR: Got %d non-finite verts" % num_infinite)
                return
            loss, losses = None, {}
            if num_infinite == 0:
                loss, losses = loss_fn(
                    voxel_scores, meshes_pred, voxels_gt, (points_gt, normals_gt)
                )

            skip = loss is None
            if loss is not None and (torch.isfinite(loss) == 0).sum().item() > 0:
                logger.info("WARNING: Got non-finite loss %f" % loss)
                skip = True
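            # In voxel-only mode, zero out every loss except the voxel loss so
            # only the voxel head receives a meaningful gradient.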
            if model_kwargs.get("voxel_only", False):
                for k, v in losses.items():
                    if k != "voxel":
                        losses[k] = 0.0 * v
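            # Periodically log the per-head losses, GPU memory usage, and the
            # size of the final predicted mesh on the main process.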
            if loss is not None and cp.t % cfg.SOLVER.LOGGING_PERIOD == 0:
                if comm.is_main_process():
                    cp.store_metric(loss=loss.item())
                    str_out = "Iteration: %d, epoch: %d, lr: %.5f," % (
                        cp.t,
                        cp.epoch,
                        optimizer.param_groups[0]["lr"],
                    )
                    for k, v in losses.items():
                        str_out += " %s loss: %.4f," % (k, v.item())
                    str_out += " total loss: %.4f," % loss.item()

                    # memory allocated
                    if torch.cuda.is_available():
                        max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
                        str_out += " mem: %d" % max_mem_mb

                    if len(meshes_pred) > 0:
                        mean_V = meshes_pred[-1].num_verts_per_mesh().float().mean().item()
                        mean_F = meshes_pred[-1].num_faces_per_mesh().float().mean().item()
                        str_out += ", mesh size = (%d, %d)" % (mean_V, mean_F)
                    logger.info(str_out)
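            # Initialize the loss moving average from the first observed loss.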
            if loss_moving_average is None and loss is not None:
                loss_moving_average = loss.item()

            # Skip backprop for this batch if the loss is above the skip factor times
            # the moving average for losses
            if loss is None:
                pass
            elif loss.item() > cfg.SOLVER.SKIP_LOSS_THRESH * loss_moving_average:
                logger.info("Warning: Skipping loss %f on GPU %d" % (loss.item(), comm.get_rank()))
                cp.store_metric(losses_skipped=loss.item())
                skip = True
            else:
                # Update the moving average of our loss
                gamma = cfg.SOLVER.LOSS_SKIP_GAMMA
                loss_moving_average *= gamma
                loss_moving_average += (1.0 - gamma) * loss.item()
                cp.store_data("loss_moving_average", loss_moving_average)
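            # For skipped batches, backprop a zero-valued loss over all parameters
            # (this keeps every GPU participating in the distributed gradient reduction).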
            if skip:
                logger.info("Dummy backprop on GPU %d" % comm.get_rank())
                loss = 0.0 * sum(p.sum() for p in params)

            # Backprop and step
            scheduler.step()
            optimizer.zero_grad()
            with Timer("Backward"):
                loss.backward()
            # When training with normal loss, sometimes I get NaNs in gradient that
            # cause the model to explode. Check for this before performing a gradient
            # update. This is safe in multi-GPU since gradients have already been
            # summed, so each GPU has the same gradients.
            num_infinite_grad = 0
            for p in params:
                num_infinite_grad += (torch.isfinite(p.grad) == 0).sum().item()
            if num_infinite_grad == 0:
                optimizer.step()
            else:
                msg = "WARNING: Got %d non-finite elements in gradient; skipping update"
                logger.info(msg % num_infinite_grad)
            cp.step()
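            # Periodically checkpoint and evaluate via eval_and_save.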
            if cp.t % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
                eval_and_save(model, loaders, optimizer, scheduler, cp)
        cp.step_epoch()
    eval_and_save(model, loaders, optimizer, scheduler, cp)
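    # Training is done; run the final evaluation on the test split (main process only).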
    if comm.is_main_process():
        logger.info("Evaluating on test set:")
        test_loader = build_data_loader(cfg, "MeshVox", "test", multigpu=False)
        evaluate_test(model, test_loader)