# Excerpt: def forward()
#
# from easycv/models/detection/detectors/dino/dino_head.py [0:0]


    def forward(self,
                feats,
                img_metas,
                query_embed=None,
                tgt=None,
                attn_mask=None,
                dn_meta=None):
        """Forward function of the DINO detection head.

        Builds per-level padding masks and positional encodings, runs the
        deformable transformer, then decodes per-decoder-layer
        classification, box, and (optionally) centerness / IoU-aware
        predictions, splitting off denoising outputs when CDN is enabled.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            img_metas (list[dict]): List of image information. Each dict must
                provide 'batch_input_shape' (padded H, W of the whole batch)
                and 'img_shape' (valid h, w, c of that image).
            query_embed (Tensor, optional): Query embeddings forwarded to
                ``self.transformer``. Expected shape depends on the
                transformer implementation — TODO confirm.
            tgt (Tensor, optional): Initial decoder target embeddings,
                forwarded to ``self.transformer`` (used by the denoising
                branch — presumably; verify against the caller).
            attn_mask (Tensor, optional): Attention mask forwarded to
                ``self.transformer``.
            dn_meta (dict, optional): Denoising meta info. When given and
                ``self.dn_number > 0``, outputs are post-processed with
                ``cdn_post_process`` to separate denoising and matching parts.

        Returns:
            dict: Prediction dict with keys:
                - 'pred_logits' (Tensor): last decoder layer classification
                  scores, shape [bs, num_query, cls_out_channels].
                - 'pred_boxes' (Tensor): last decoder layer sigmoid-normalized
                  (cx, cy, w, h) boxes, shape [bs, num_query, 4].
                - 'pred_centers' (Tensor | None): last-layer centerness
                  predictions, or None when ``self.use_centerness`` is False.
                - 'pred_ious' (Tensor | None): last-layer IoU-aware
                  predictions, or None when ``self.use_iouaware`` is False.
                - 'refpts' (Tensor): last-layer reference points (cx, cy).
                - 'aux_outputs': per-intermediate-layer outputs from
                  ``self._set_aux_loss``.
                - 'interm_outputs' (dict, optional): encoder-stage proposals,
                  present only when the transformer returns ``hs_enc``.
                - 'dn_meta': the ``dn_meta`` argument, passed through.
        """
        # construct binary masks which used for the transformer.
        # NOTE following the official DETR repo, non-zero values representing
        # ignored positions, while zero values means valid positions.
        bs = feats[0].size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        img_masks = feats[0].new_ones((bs, input_img_h, input_img_w))
        for img_id in range(bs):
            # Zero out (mark valid) the region actually covered by each image;
            # the rest of the padded canvas stays 1 (ignored).
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0

        # Per feature level: project channels via input_proj, downsample the
        # padding mask to the level's spatial size, and compute its
        # positional encoding.
        srcs = []
        masks = []
        poss = []
        for l, src in enumerate(feats):
            mask = F.interpolate(
                img_masks[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
            # position encoding
            pos_l = self.positional_encoding(mask)  # [bs, embed_dim, h, w]
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            poss.append(pos_l)
            assert mask is not None
        # If the transformer expects more levels than the backbone provides,
        # synthesize extra levels by applying the remaining input_proj modules
        # (first to the last backbone feature, then to the previous output).
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](feats[-1])
                else:
                    src = self.input_proj[l](srcs[-1])
                mask = F.interpolate(
                    img_masks[None].float(),
                    size=src.shape[-2:]).to(torch.bool)[0]
                # position encoding
                pos_l = self.positional_encoding(mask)  # [bs, embed_dim, h, w]
                srcs.append(src)
                masks.append(mask)
                poss.append(pos_l)

        # hs: per-decoder-layer hidden states; reference: per-layer reference
        # points; hs_enc/ref_enc: encoder (two-stage) outputs, may be None.
        hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
            srcs, masks, query_embed, poss, tgt, attn_mask)
        # In case num object=0
        # Adds zero, so values are unchanged — presumably keeps label_enc in
        # the autograd graph when there are no denoising queries; confirm.
        hs[0] += self.label_enc.weight[0, 0] * 0.0

        # deformable-detr-like anchor update
        # reference_before_sigmoid = inverse_sigmoid(reference[:-1]) # n_dec, bs, nq, 4
        # Each decoder layer predicts a delta in unsigmoided space, added to
        # the (unsigmoided) reference of that layer, then squashed back.
        outputs_coord_list = []
        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
                zip(reference[:-1], self.bbox_embed, hs)):
            layer_delta_unsig = layer_bbox_embed(layer_hs)
            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(
                layer_ref_sig)
            layer_outputs_unsig = layer_outputs_unsig.sigmoid()
            outputs_coord_list.append(layer_outputs_unsig)
        outputs_coord_list = torch.stack(outputs_coord_list)

        # outputs_class = self.class_embed(hs)
        # Per-layer classification heads applied to per-layer hidden states.
        outputs_class = torch.stack([
            layer_cls_embed(layer_hs)
            for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
        ])

        # Optional centerness branch (one head per decoder layer).
        outputs_center_list = None
        if self.use_centerness:
            outputs_center_list = torch.stack([
                layer_center_embed(layer_hs)
                for layer_center_embed, layer_hs in zip(self.center_embed, hs)
            ])

        # Optional IoU-aware branch (one head per decoder layer).
        outputs_iou_list = None
        if self.use_iouaware:
            outputs_iou_list = torch.stack([
                layer_iou_embed(layer_hs)
                for layer_iou_embed, layer_hs in zip(self.iou_embed, hs)
            ])

        # Drop the post-last-layer reference and keep only (cx, cy).
        reference = torch.stack(reference)[:-1][..., :2]
        # Split denoising-query outputs from matching-query outputs.
        if self.dn_number > 0 and dn_meta is not None:
            outputs_class, outputs_coord_list, outputs_center_list, outputs_iou_list, reference = cdn_post_process(
                outputs_class, outputs_coord_list, dn_meta, self._set_aux_loss,
                outputs_center_list, outputs_iou_list, reference)
        # Index [-1] selects the last decoder layer's predictions.
        out = {
            'pred_logits':
            outputs_class[-1],
            'pred_boxes':
            outputs_coord_list[-1],
            'pred_centers':
            outputs_center_list[-1]
            if outputs_center_list is not None else None,
            'pred_ious':
            outputs_iou_list[-1] if outputs_iou_list is not None else None,
            'refpts':
            reference[-1],
        }

        # Auxiliary outputs from the intermediate decoder layers.
        out['aux_outputs'] = self._set_aux_loss(outputs_class,
                                                outputs_coord_list,
                                                outputs_center_list,
                                                outputs_iou_list, reference)

        # for encoder output
        if hs_enc is not None:
            # prepare intermediate outputs
            interm_coord = ref_enc[-1]
            interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
            if self.use_centerness:
                interm_center = self.transformer.enc_out_center_embed(
                    hs_enc[-1])
            if self.use_iouaware:
                interm_iou = self.transformer.enc_out_iou_embed(hs_enc[-1])
            out['interm_outputs'] = {
                'pred_logits': interm_class,
                'pred_boxes': interm_coord,
                'pred_centers': interm_center if self.use_centerness else None,
                'pred_ious': interm_iou if self.use_iouaware else None,
                'refpts': init_box_proposal[..., :2],
            }

        # Pass dn_meta through so the loss can split denoising groups.
        out['dn_meta'] = dn_meta

        return out