in easycv/models/detection/detectors/dino/dino_head.py [0:0]
def forward(self,
            feats,
            img_metas,
            query_embed=None,
            tgt=None,
            attn_mask=None,
            dn_meta=None):
    """Forward pass of the DINO detection head.

    Builds padding masks and positional encodings for each feature level,
    runs the deformable transformer, applies per-layer box/class (and
    optionally centerness / IoU-aware) heads, post-processes contrastive
    denoising (CDN) outputs, and assembles the final prediction dict.

    Args:
        feats (tuple[Tensor]): Multi-scale features from the upstream
            network, each a 4D tensor (batch-first; spatial dims are read
            from ``src.shape[-2:]``).
        img_metas (list[dict]): Per-image meta info. Each dict must provide
            ``'batch_input_shape'`` (padded batch H, W) and ``'img_shape'``
            (valid H, W, C region of that image).
        query_embed (Tensor | None): Optional query embeddings forwarded to
            the transformer.
        tgt (Tensor | None): Optional initial decoder target (e.g. the
            denoising queries), forwarded to the transformer.
        attn_mask (Tensor | None): Optional attention mask forwarded to the
            transformer. NOTE(review): presumably keeps denoising and
            matching queries from attending to each other (DN-DETR/DINO
            convention) — confirm against the transformer implementation.
        dn_meta (dict | None): Denoising meta info. When provided and
            ``self.dn_number > 0``, outputs are split/post-processed by
            ``cdn_post_process`` and the dict is echoed back under the
            ``'dn_meta'`` key.

    Returns:
        dict: Prediction dict with keys:
            - ``'pred_logits'`` (Tensor): classification logits from the
              last decoder layer, shape [bs, num_query, cls_out_channels].
            - ``'pred_boxes'`` (Tensor): sigmoid-normalized (cx, cy, w, h)
              boxes from the last decoder layer, shape [bs, num_query, 4].
            - ``'pred_centers'`` (Tensor | None): centerness outputs of the
              last decoder layer; ``None`` unless ``self.use_centerness``.
            - ``'pred_ious'`` (Tensor | None): IoU-aware outputs of the
              last decoder layer; ``None`` unless ``self.use_iouaware``.
            - ``'refpts'`` (Tensor): (cx, cy) reference points of the last
              decoder layer.
            - ``'aux_outputs'`` (via ``self._set_aux_loss``): per-layer
              outputs for the intermediate decoder layers.
            - ``'interm_outputs'`` (dict, only when the transformer returns
              encoder outputs ``hs_enc``): two-stage encoder proposals with
              the same key layout.
            - ``'dn_meta'``: the ``dn_meta`` argument, passed through.
    """
    # construct binary masks which used for the transformer.
    # NOTE following the official DETR repo, non-zero values representing
    # ignored positions, while zero values means valid positions.
    bs = feats[0].size(0)
    input_img_h, input_img_w = img_metas[0]['batch_input_shape']
    # Start all-ones (all padded); zero out each image's valid region below.
    img_masks = feats[0].new_ones((bs, input_img_h, input_img_w))
    for img_id in range(bs):
        img_h, img_w, _ = img_metas[img_id]['img_shape']
        img_masks[img_id, :img_h, :img_w] = 0
    srcs = []   # projected per-level features
    masks = []  # per-level boolean padding masks
    poss = []   # per-level positional encodings
    for l, src in enumerate(feats):
        # Downsample the image-resolution mask to this level's spatial size.
        mask = F.interpolate(
            img_masks[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
        # position encoding
        pos_l = self.positional_encoding(mask)  # [bs, embed_dim, h, w]
        srcs.append(self.input_proj[l](src))
        masks.append(mask)
        poss.append(pos_l)
        assert mask is not None
    # If the head expects more feature levels than the neck provides,
    # synthesize extra levels by projecting (and implicitly striding down)
    # the last available map.
    if self.num_feature_levels > len(srcs):
        _len_srcs = len(srcs)
        for l in range(_len_srcs, self.num_feature_levels):
            if l == _len_srcs:
                # First extra level comes from the raw top feature map...
                src = self.input_proj[l](feats[-1])
            else:
                # ...subsequent ones from the previously synthesized level.
                src = self.input_proj[l](srcs[-1])
            mask = F.interpolate(
                img_masks[None].float(),
                size=src.shape[-2:]).to(torch.bool)[0]
            # position encoding
            pos_l = self.positional_encoding(mask)  # [bs, embed_dim, h, w]
            srcs.append(src)
            masks.append(mask)
            poss.append(pos_l)
    hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
        srcs, masks, query_embed, poss, tgt, attn_mask)
    # In case num object=0: multiply by 0.0 so the value is unchanged but
    # label_enc still participates in the autograd graph (avoids unused
    # parameter errors, e.g. under DDP).
    hs[0] += self.label_enc.weight[0, 0] * 0.0
    # deformable-detr-like anchor update
    # reference_before_sigmoid = inverse_sigmoid(reference[:-1]) # n_dec, bs, nq, 4
    outputs_coord_list = []
    for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
            zip(reference[:-1], self.bbox_embed, hs)):
        # Predict a delta in unsigmoided space and add it to the (also
        # unsigmoided) reference box, then squash back to [0, 1].
        layer_delta_unsig = layer_bbox_embed(layer_hs)
        layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(
            layer_ref_sig)
        layer_outputs_unsig = layer_outputs_unsig.sigmoid()
        outputs_coord_list.append(layer_outputs_unsig)
    outputs_coord_list = torch.stack(outputs_coord_list)
    # outputs_class = self.class_embed(hs)
    # Per-decoder-layer classification heads, stacked along a new leading
    # (decoder-layer) dimension.
    outputs_class = torch.stack([
        layer_cls_embed(layer_hs)
        for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
    ])
    outputs_center_list = None
    if self.use_centerness:
        outputs_center_list = torch.stack([
            layer_center_embed(layer_hs)
            for layer_center_embed, layer_hs in zip(self.center_embed, hs)
        ])
    outputs_iou_list = None
    if self.use_iouaware:
        outputs_iou_list = torch.stack([
            layer_iou_embed(layer_hs)
            for layer_iou_embed, layer_hs in zip(self.iou_embed, hs)
        ])
    # Keep only the (cx, cy) part of the per-layer references; drop the
    # last entry to align with the decoder-layer outputs above.
    reference = torch.stack(reference)[:-1][..., :2]
    if self.dn_number > 0 and dn_meta is not None:
        # NOTE(review): presumably splits the CDN (denoising) queries from
        # the matching queries and repacks aux outputs — confirm against
        # cdn_post_process.
        outputs_class, outputs_coord_list, outputs_center_list, outputs_iou_list, reference = cdn_post_process(
            outputs_class, outputs_coord_list, dn_meta, self._set_aux_loss,
            outputs_center_list, outputs_iou_list, reference)
    # Index -1 selects the last decoder layer's predictions.
    out = {
        'pred_logits':
        outputs_class[-1],
        'pred_boxes':
        outputs_coord_list[-1],
        'pred_centers':
        outputs_center_list[-1]
        if outputs_center_list is not None else None,
        'pred_ious':
        outputs_iou_list[-1] if outputs_iou_list is not None else None,
        'refpts':
        reference[-1],
    }
    # Intermediate decoder layers, used for auxiliary losses.
    out['aux_outputs'] = self._set_aux_loss(outputs_class,
                                            outputs_coord_list,
                                            outputs_center_list,
                                            outputs_iou_list, reference)
    # for encoder output (two-stage): expose the encoder's proposals so a
    # loss can be applied to them as well.
    if hs_enc is not None:
        # prepare intermediate outputs
        interm_coord = ref_enc[-1]
        interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
        if self.use_centerness:
            interm_center = self.transformer.enc_out_center_embed(
                hs_enc[-1])
        if self.use_iouaware:
            interm_iou = self.transformer.enc_out_iou_embed(hs_enc[-1])
        out['interm_outputs'] = {
            'pred_logits': interm_class,
            'pred_boxes': interm_coord,
            'pred_centers': interm_center if self.use_centerness else None,
            'pred_ious': interm_iou if self.use_iouaware else None,
            'refpts': init_box_proposal[..., :2],
        }
    out['dn_meta'] = dn_meta
    return out