in meshrcnn/modeling/roi_heads/roi_heads.py [0:0]
def _forward_shape(self, features, instances):
    """
    Forward logic for the voxel and mesh refinement branch.

    Args:
        features: input features indexable by the names in ``self.in_features``
            (presumably a dict[str, Tensor] of feature levels -- confirm against caller).
        instances (list[Instances]): the per-image instances to train/predict meshes.
            In training, they can be the proposals.
            In inference, they can be the predicted boxes.

    Returns:
        In training, a dict of losses: "loss_voxel" when the voxel branch is on, and
        "loss_chamfer", "loss_normals", "loss_edge" when the mesh branch is on.
        In inference, `instances`, updated by `voxel_rcnn_inference` /
        `mesh_rcnn_inference` (new fields "pred_voxels" & "pred_meshes").

    Raises:
        ValueError: if the voxel branch is on but the voxel head is class specific.
    """
    if not self.voxel_on and not self.mesh_on:
        return {} if self.training else instances
    # Fail fast, before computing losses or mutating `instances`.
    if self.voxel_on and not self.cls_agnostic_voxel:
        raise ValueError("No support for class specific predictions")
    features = [features[f] for f in self.in_features]

    if self.training:
        # The loss is only defined on positive proposals.
        proposals, _ = select_foreground_proposals(instances, self.num_classes)
        proposal_boxes = [x.proposal_boxes for x in proposals]
        losses = {}
        init_mesh = None
        if self.voxel_on:
            voxel_features = self.voxel_pooler(features, proposal_boxes)
            voxel_logits = self.voxel_head(voxel_features)
            loss_voxel, target_voxels = voxel_rcnn_loss(
                voxel_logits, proposals, loss_weight=self.voxel_loss_weight
            )
            losses["loss_voxel"] = loss_voxel
            if self._vis:
                self._misc["target_voxels"] = target_voxels
            if self.mesh_on:
                # In training the cubified voxels only seed the mesh head, so
                # skip the conversion entirely for voxel-only models.
                init_mesh = self._cubify_voxel_logits(voxel_logits)
        if self.mesh_on:
            mesh_features = self.mesh_pooler(features, proposal_boxes)
            if not self.voxel_on:
                init_mesh = self._ico_sphere_meshes(mesh_features)
            pred_meshes = self.mesh_head(mesh_features, init_mesh)
            if not pred_meshes[0].isempty():
                loss_chamfer, loss_normals, loss_edge, target_meshes = mesh_rcnn_loss(
                    pred_meshes,
                    proposals,
                    loss_weights={
                        "chamfer": self.chamfer_loss_weight,
                        "normals": self.normals_loss_weight,
                        "edge": self.edge_loss_weight,
                    },
                    gt_num_samples=self.gt_num_samples,
                    pred_num_samples=self.pred_num_samples,
                    gt_coord_thresh=self.gt_coord_thresh,
                )
                if self._vis:
                    self._misc["init_meshes"] = init_mesh
                    self._misc["target_meshes"] = target_meshes
            else:
                # No predicted meshes: emit a zero loss that still depends on every
                # mesh-head parameter (with zero gradient) -- presumably so distributed
                # training sees all parameters used. Computed once and shared.
                zero_loss = sum(p.sum() for p in self.mesh_head.parameters()) * 0.0
                loss_chamfer = loss_normals = loss_edge = zero_loss
            losses.update(
                {
                    "loss_chamfer": loss_chamfer,
                    "loss_normals": loss_normals,
                    "loss_edge": loss_edge,
                }
            )
        return losses

    pred_boxes = [x.pred_boxes for x in instances]
    init_mesh = None
    if self.voxel_on:
        voxel_features = self.voxel_pooler(features, pred_boxes)
        voxel_logits = self.voxel_head(voxel_features)
        voxel_rcnn_inference(voxel_logits, instances)
        init_mesh = self._cubify_voxel_logits(voxel_logits)
    if self.mesh_on:
        mesh_features = self.mesh_pooler(features, pred_boxes)
        if not self.voxel_on:
            init_mesh = self._ico_sphere_meshes(mesh_features)
        pred_meshes = self.mesh_head(mesh_features, init_mesh)
        # Only the last refinement stage is reported.
        mesh_rcnn_inference(pred_meshes[-1], instances)
    else:
        # Voxel-only model (the early return guarantees voxel_on here):
        # the cubified voxels are the final meshes.
        mesh_rcnn_inference(init_mesh, instances)
    return instances


def _cubify_voxel_logits(self, voxel_logits):
    """
    Convert class-agnostic voxel logits into initial meshes with `cubify`.

    Runs under ``torch.no_grad()``, so nothing computed here is differentiable.

    Args:
        voxel_logits (Tensor): assumed (N, 1, V, V, V); channel dim 1 is squeezed
            away to give the (N, V, V, V) occupancy probabilities -- TODO confirm.

    Returns:
        Whatever `cubify` returns for the probabilities thresholded at
        ``self.cubify_thresh``.
    """
    with torch.no_grad():
        vox_in = voxel_logits.sigmoid().squeeze(1)  # (N, V, V, V)
        return cubify(vox_in, self.cubify_thresh)


def _ico_sphere_meshes(self, mesh_features):
    """
    Build one ico-sphere per RoI as the initial mesh when voxels are disabled.

    Args:
        mesh_features (Tensor): pooled RoI features; dim 0 is the number of RoIs.

    Returns:
        A `Meshes` batch of ``mesh_features.shape[0]`` ico-spheres of level
        ``self.ico_sphere_level`` on ``mesh_features.device``, or an empty `Meshes`
        when there are no RoIs.
    """
    num_rois = mesh_features.shape[0]
    if num_rois == 0:
        return Meshes(verts=[], faces=[])
    sphere = ico_sphere(self.ico_sphere_level, mesh_features.device)
    return sphere.extend(num_rois)