in detic/modeling/meta_arch/custom_rcnn.py [0:0]
def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
    """
    Run one training step (or delegate to inference in eval mode).

    Add ann_type.
    Ignore proposal loss when training with image labels.

    Args:
        batched_inputs: one dict per image. Every dict must provide
            "instances". Depending on the model configuration these keys
            are read as well:
              * 'ann_type' and 'pos_category_ids' when
                ``self.with_image_labels`` is set,
              * 'captions' when ``self.with_caption`` is set and the batch
                annotation type contains 'caption',
              * 'dataset_source' when ``self.dataset_loss_weight`` is
                non-empty.

    Returns:
        In eval mode, whatever ``self.inference(batched_inputs)`` returns.
        In training mode, a dict of losses (detector losses followed by
        proposal losses), or ``(proposals, losses)`` when
        ``self.return_proposal`` is set.

    Raises:
        AssertionError: if the batch mixes several annotation types or
            several dataset sources.
    """
    if not self.training:
        return self.inference(batched_inputs)

    images = self.preprocess_image(batched_inputs)
    gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

    # Without image-level supervision every batch is treated as box data.
    ann_type = 'box'
    if self.with_image_labels:
        for inst, x in zip(gt_instances, batched_inputs):
            inst._ann_type = x['ann_type']
            inst._pos_category_ids = x['pos_category_ids']
        ann_types = [x['ann_type'] for x in batched_inputs]
        assert len(set(ann_types)) == 1, \
            'all images in a batch must share one ann_type, got {}'.format(
                sorted(set(ann_types)))
        ann_type = ann_types[0]
        if ann_type in ('prop', 'proptag'):
            # Zero every ground-truth class id (in place) for these types.
            for t in gt_instances:
                t.gt_classes *= 0

    if self.fp16:  # TODO (zhouxy): improve
        with autocast():
            features = self.backbone(images.tensor.half())
        # Hand float32 feature maps to the heads.
        features = {k: v.float() for k, v in features.items()}
    else:
        features = self.backbone(images.tensor)

    cls_features, cls_inds, caption_features = None, None, None

    if self.with_caption and 'caption' in ann_type:
        # Draw one caption index per image at random.
        inds = [torch.randint(len(x['captions']), (1,)).item()
                for x in batched_inputs]
        caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)]
        caption_features = self.text_encoder(caps).float()
    if self.sync_caption_batch:
        caption_features = self._sync_caption_features(
            caption_features, ann_type, len(batched_inputs))

    if self.dynamic_classifier and ann_type != 'caption':
        cls_inds = self._sample_cls_inds(gt_instances, ann_type)  # inds, inv_inds
        # Append index -1 (last column of zs_weight) after the sampled ones.
        ind_with_bg = cls_inds[0].tolist() + [-1]
        zs_weight = self.roi_heads.box_predictor[0].cls_score.zs_weight
        cls_features = zs_weight[:, ind_with_bg].permute(1, 0).contiguous()

    classifier_info = cls_features, cls_inds, caption_features

    proposals, proposal_losses = self.proposal_generator(
        images, features, gt_instances)

    if self.roi_head_name in ('StandardROIHeads', 'CascadeROIHeads'):
        # These heads are called without the extra keyword arguments.
        proposals, detector_losses = self.roi_heads(
            images, features, proposals, gt_instances)
    else:
        proposals, detector_losses = self.roi_heads(
            images, features, proposals, gt_instances,
            ann_type=ann_type, classifier_info=classifier_info)

    if self.vis_period > 0:
        storage = get_event_storage()
        if storage.iter % self.vis_period == 0:
            self.visualize_training(batched_inputs, proposals)

    losses = {}
    losses.update(detector_losses)
    if self.with_image_labels and ann_type not in ('box', 'prop', 'proptag'):
        # Ignore proposal loss for non-bbox data: keep the keys, zero the values.
        losses.update({k: v * 0 for k, v in proposal_losses.items()})
    else:
        losses.update(proposal_losses)

    if len(self.dataset_loss_weight) > 0:
        dataset_sources = [x['dataset_source'] for x in batched_inputs]
        assert len(set(dataset_sources)) == 1, \
            'all images in a batch must share one dataset_source, got {}'.format(
                sorted(set(dataset_sources)))
        weight = self.dataset_loss_weight[dataset_sources[0]]
        # Scale out of place: `losses[k] *= weight` would modify the loss
        # tensors returned by the sub-modules in place.
        losses = {k: v * weight for k, v in losses.items()}

    if self.return_proposal:
        return proposals, losses
    return losses