in d2go/modeling/subclass.py [0:0]
def forward(self, images, features, proposals, targets=None):
    """
    Same as StandardROIHeads.forward but add logic for subclass.

    When ``self.subclass_on`` is False this defers entirely to the parent
    implementation. Otherwise it runs the box head, feeds the flattened box
    features through ``self.subclass_head`` and:

    * in training, adds a ``"loss_subclass"`` cross-entropy term computed
      against each sampled proposal's ``gt_subclasses`` (proposals without
      that field are assigned subclass 0, treated as background);
    * in inference, attaches ``pred_subclass_prob`` (softmax of the subclass
      logits, presumably shaped (num_kept, num_subclasses)) to each image's
      kept detections.

    Args:
        images: not used on the subclass path (released immediately).
        features (dict[str, Tensor]): feature maps keyed by name; only the
            entries named in ``self.in_features`` are pooled.
        proposals (list[Instances]): per-image proposals carrying
            ``proposal_boxes``.
        targets (list[Instances] or None): ground truth, consumed in training
            by ``self.label_and_sample_proposals``.

    Returns:
        Training: ``(proposals, losses)``; ``losses`` contains the box
        predictor losses, whatever ``_forward_mask`` / ``_forward_keypoint``
        return, and ``"loss_subclass"``.
        Inference: ``(pred_instances, {})``.
    """
    if not self.subclass_on:
        return super().forward(images, features, proposals, targets)
    # --- start copy -------------------------------------------------------
    del images
    if self.training:
        proposals = self.label_and_sample_proposals(proposals, targets)
        # NOTE: `has_gt` = False for negatives and we must manually register `gt_subclasses`,
        # because custom gt_* fields will not be automatically registered in sampled proposals.
        background_subcls_idx = 0
        for pp_per_im in proposals:
            if not pp_per_im.has("gt_subclasses"):
                # Allocate on the proposals' own device rather than hard-coding
                # CUDA, so training also works on CPU.
                pp_per_im.gt_subclasses = torch.full(
                    (len(pp_per_im),),
                    background_subcls_idx,
                    dtype=torch.int64,
                    device=pp_per_im.proposal_boxes.tensor.device,
                )
    del targets
    features_list = [features[f] for f in self.in_features]
    box_features = self.box_pooler(
        features_list, [x.proposal_boxes for x in proposals]
    )
    box_features = self.box_head(box_features)
    predictions = self.box_predictor(box_features)
    # --- end copy ---------------------------------------------------------
    # box_features is intentionally kept alive: the subclass head consumes it.
    # Collapse every dimension after the first (per-proposal) one.
    box_features = box_features.flatten(start_dim=1)
    pred_subclass_logits = self.subclass_head(box_features)

    if self.training:
        losses = self.box_predictor.losses(predictions, proposals)
        # During training the proposals used by the box head are
        # used by the mask, keypoint (and densepose) heads.
        losses.update(self._forward_mask(features, proposals))
        losses.update(self._forward_keypoint(features, proposals))
        # subclass head
        gt_subclasses = cat([p.gt_subclasses for p in proposals], dim=0)
        loss_subclass = F.cross_entropy(
            pred_subclass_logits, gt_subclasses, reduction="mean"
        )
        losses.update({"loss_subclass": loss_subclass})
        return proposals, losses

    pred_instances, kept_indices = self.box_predictor.inference(
        predictions, proposals
    )
    # During inference cascaded prediction is used: the mask and keypoints
    # heads are only applied to the top scoring box detections.
    pred_instances = self.forward_with_given_boxes(features, pred_instances)
    # subclass head
    probs = F.softmax(pred_subclass_logits, dim=-1)
    # Rows of `probs` are all images' proposals concatenated in order, while
    # `kept_indices[i]` indexes into image i's own proposals (detectron2's
    # fast_rcnn_inference contract). Split per image so that images after the
    # first select their own rows instead of rows belonging to image 0.
    if len(proposals) == 1:
        # Single image (the only case allowed under ONNX export, see below):
        # skip the split so the exported graph bakes in no proposal counts.
        probs_per_image = [probs]
    else:
        probs_per_image = probs.split([len(p) for p in proposals], dim=0)
    for pred_instances_i, kept_indices_i, probs_i in zip(
        pred_instances, kept_indices, probs_per_image
    ):
        pred_instances_i.pred_subclass_prob = torch.index_select(
            probs_i,
            dim=0,
            index=kept_indices_i.to(torch.int64),
        )
    if torch.onnx.is_in_onnx_export():
        assert len(pred_instances) == 1
        pred_instances[0].pred_subclass_prob = alias(
            pred_instances[0].pred_subclass_prob, "subclass_prob_nms"
        )
    return pred_instances, {}