in detic/modeling/roi_heads/res5_roi_heads.py [0:0]
def forward(self, images, features, proposals, targets=None,
            ann_type='box', classifier_info=(None, None, None)):
    """Run the res5 ROI head, with optional debug visualisation and image labels.

    Args:
        images: input images. They are only used when ``self.save_debug`` is
            set (to draw debug visualisations); otherwise the reference is
            dropped right away.
        features: mapping of feature maps by name; the entries listed in
            ``self.in_features`` are passed to ``self._shared_roi_transform``.
        proposals: per-image proposal objects exposing ``proposal_boxes``.
            During training they are replaced by
            ``label_and_sample_proposals`` (``ann_type == 'box'``) or
            ``get_top_proposals`` (any other ``ann_type``).
        targets: per-image ground truth, read during training (and by the
            training debug path).
        ann_type: ``'box'`` selects the box-predictor losses; any other value
            selects ``box_predictor.image_label_losses`` using each target's
            ``_pos_category_ids``.
        classifier_info: shared across the whole batch; forwarded unchanged
            to the box predictor.

    Returns:
        In training mode, ``(proposals, losses)``. In inference mode,
        ``(pred_instances, {})``.
    """
    if not self.save_debug:
        # Images are only needed for debug drawing; release the reference early.
        del images

    if self.training:
        if ann_type == 'box':
            proposals = self.label_and_sample_proposals(proposals, targets)
        else:
            proposals = self.get_top_proposals(proposals)

    proposal_boxes = [x.proposal_boxes for x in proposals]
    box_features = self._shared_roi_transform(
        [features[f] for f in self.in_features], proposal_boxes
    )
    # Average over dims 2 and 3 (presumably the spatial axes of the ROI
    # features). Reduced once and reused below instead of recomputing it.
    pooled_features = box_features.mean(dim=[2, 3])
    predictions = self.box_predictor(
        pooled_features, classifier_info=classifier_info)

    if self.add_feature_to_prop:
        # Assumes rows of pooled_features are concatenated in proposal order,
        # so splitting by per-image counts gives each image its own rows.
        feats_per_image = pooled_features.split(
            [len(p) for p in proposals], dim=0)
        for feat, p in zip(feats_per_image, proposals):
            p.feat = feat

    if self.training:
        del features
        if ann_type != 'box':
            image_labels = [x._pos_category_ids for x in targets]
            losses = self.box_predictor.image_label_losses(
                predictions, proposals, image_labels,
                classifier_info=classifier_info,
                ann_type=ann_type)
        else:
            losses = self.box_predictor.losses(
                (predictions[0], predictions[1]), proposals)
            if self.with_image_labels:
                # Zero placeholder, presumably so box batches report the same
                # loss keys as image-label batches.
                assert 'image_loss' not in losses
                losses['image_loss'] = predictions[0].new_zeros([1])[0]
        if self.save_debug:
            denormalizer = lambda x: x * self.pixel_std + self.pixel_mean
            if ann_type != 'box':
                image_labels = [x._pos_category_ids for x in targets]
            else:
                image_labels = [[] for _ in targets]
            debug_second_stage(
                [denormalizer(x.clone()) for x in images],
                targets, proposals=proposals,
                save_debug=self.save_debug,
                debug_show_name=self.debug_show_name,
                vis_thresh=self.vis_thresh,
                image_labels=image_labels,
                save_debug_path=self.save_debug_path,
                bgr=self.bgr)
        return proposals, losses

    pred_instances, _ = self.box_predictor.inference(predictions, proposals)
    pred_instances = self.forward_with_given_boxes(features, pred_instances)
    if self.save_debug:
        denormalizer = lambda x: x * self.pixel_std + self.pixel_mean
        debug_second_stage(
            [denormalizer(x.clone()) for x in images],
            pred_instances, proposals=proposals,
            save_debug=self.save_debug,
            debug_show_name=self.debug_show_name,
            vis_thresh=self.vis_thresh,
            save_debug_path=self.save_debug_path,
            bgr=self.bgr)
    return pred_instances, {}