in tools/train_net_shapenet.py [0:0]
def main_worker(worker_id, args):
    """Per-process entry point: set up data, model, and optimization, then run the training loop."""
    distributed = False
    if args.num_gpus > 1:
        # Multi-GPU run: join the process group and pin this worker to its GPU.
        distributed = True
        dist.init_process_group(
            backend="NCCL", init_method=args.dist_url, world_size=args.num_gpus, rank=worker_id
        )
        torch.cuda.set_device(worker_id)
    device = torch.device("cuda:%d" % worker_id)

    cfg = setup(args)

    # data loaders
    loaders = setup_loaders(cfg)
    for split_name, loader in loaders.items():
        logger.info("%s - %d" % (split_name, len(loader)))

    # build the model
    model = build_model(cfg)
    model.to(device)
    if distributed:
        # Wrap the model for synchronized multi-GPU training.
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[worker_id],
            output_device=worker_id,
            check_reduction=True,
            broadcast_buffers=False,
        )

    # Optimizer and LR schedule; the total iteration budget is epochs * iterations per epoch.
    optimizer = build_optimizer(cfg, model)
    cfg.SOLVER.COMPUTED_MAX_ITERS = cfg.SOLVER.NUM_EPOCHS * len(loaders["train"])
    scheduler = build_lr_scheduler(cfg, optimizer)

    # Per-term weights and sampling sizes for the mesh/voxel losses.
    loss_fn_kwargs = {
        "chamfer_weight": cfg.MODEL.MESH_HEAD.CHAMFER_LOSS_WEIGHT,
        "normal_weight": cfg.MODEL.MESH_HEAD.NORMALS_LOSS_WEIGHT,
        "edge_weight": cfg.MODEL.MESH_HEAD.EDGE_LOSS_WEIGHT,
        "voxel_weight": cfg.MODEL.VOXEL_HEAD.LOSS_WEIGHT,
        "gt_num_samples": cfg.MODEL.MESH_HEAD.GT_NUM_SAMPLES,
        "pred_num_samples": cfg.MODEL.MESH_HEAD.PRED_NUM_SAMPLES,
    }
    loss_fn = MeshLoss(**loss_fn_kwargs)

    # Checkpointing: record fresh-run metadata, or resume model/optimizer/scheduler state.
    checkpoint_path = os.path.join(cfg.OUTPUT_DIR, "checkpoint.pt")
    cp = Checkpoint(checkpoint_path)
    if len(cp.restarts) == 0:
        # We are starting from scratch, so store some initial data in cp
        iter_per_epoch = len(loaders["train"])
        cp.store_data("iter_per_epoch", iter_per_epoch)
    else:
        logger.info("Loading model state from checkpoint")
        model.load_state_dict(cp.latest_states["model"])
        optimizer.load_state_dict(cp.latest_states["optim"])
        scheduler.load_state_dict(cp.latest_states["lr_scheduler"])

    training_loop(cfg, cp, model, optimizer, scheduler, loaders, device, loss_fn)
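# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a worker function with this
# (worker_id, args) signature is typically launched once per GPU with
# torch.multiprocessing.spawn. The argument parser below is hypothetical and
# only mirrors the two attributes main_worker reads here (num_gpus, dist_url);
# the real script defines additional options (e.g. the config file) that
# setup(args) consumes.
import argparse
import torch.multiprocessing as mp

def _launch(args):
    if args.num_gpus > 1:
        # Spawn one process per GPU; each process receives its rank as worker_id.
        mp.spawn(main_worker, nprocs=args.num_gpus, args=(args,))
    else:
        main_worker(0, args)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-gpus", dest="num_gpus", type=int, default=1)
    parser.add_argument("--dist-url", dest="dist_url", default="tcp://127.0.0.1:23456")
    _launch(parser.parse_args())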