def forward()

in easycv/models/detection3d/detectors/bevformer/bevformer_head.py


    def forward(self, mlvl_feats, img_metas, prev_bev=None, only_bev=False):
        """Forward function.
        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each a 5D tensor with shape
                (B, N, C, H, W).
            img_metas (list[dict]): Meta information of each sample.
            prev_bev (Tensor, optional): BEV features from the previous frame.
            only_bev (bool): If True, only compute BEV features with the encoder.
        Returns:
            all_cls_scores (Tensor): Outputs from the classification head, \
                shape [nb_dec, bs, num_query, cls_out_channels]. Note that \
                cls_out_channels should include the background class.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
                head in normalized coordinate format \
                (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \
                Shape [nb_dec, bs, num_query, 10].
        """

        bs, num_cam, _, _, _ = mlvl_feats[0].shape
        dtype = mlvl_feats[0].dtype
        object_query_embeds = self.query_embedding.weight.to(dtype)
        bev_queries = self.bev_embedding.weight.to(dtype)

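        # The all-zero mask marks every BEV cell as valid; it exists only so
        # that positional_encoding can produce embeddings on the right grid.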
        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device).to(dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)

        if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
            return self.transformer.get_bev_features(
                mlvl_feats,
                bev_queries,
                self.bev_h,
                self.bev_w,
                grid_length=(self.real_h / self.bev_h,
                             self.real_w / self.bev_w),
                bev_pos=bev_pos,
                img_metas=img_metas,
                prev_bev=prev_bev,
            )
        else:
            # make attn mask for one2many task
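            # True entries are masked out: the first num_query_one2one
            # (one-to-one) queries and the remaining auxiliary one-to-many
            # queries must not attend to each other, so the two groups are
            # decoded independently (hybrid matching).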
            self_attn_mask = torch.zeros(
                (self.num_query, self.num_query),
                dtype=torch.bool,
                device=bev_queries.device)
            self_attn_mask[self.num_query_one2one:,
                           :self.num_query_one2one] = True
            self_attn_mask[:self.num_query_one2one,
                           self.num_query_one2one:] = True

            outputs = self.transformer(
                mlvl_feats,
                bev_queries,
                object_query_embeds,
                self.bev_h,
                self.bev_w,
                grid_length=(self.real_h / self.bev_h,
                             self.real_w / self.bev_w),
                bev_pos=bev_pos,
                reg_branches=self.reg_branches
                if self.with_box_refine else None,  # noqa:E501
                cls_branches=self.cls_branches if self.as_two_stage else None,
                img_metas=img_metas,
                prev_bev=prev_bev,
                attn_mask=self_attn_mask)

        bev_embed, hs, init_reference, inter_references = outputs
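        # hs: (num_dec_layers, num_query, bs, embed_dims) ->
        #     (num_dec_layers, bs, num_query, embed_dims)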
        hs = hs.permute(0, 2, 1, 3)
        outputs_classes = []
        outputs_coords = []
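        # Each decoder layer yields its own classification / regression
        # output; layer l refines the reference points produced by layer
        # l - 1 (init_reference for the first layer).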
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.cls_branches[lvl](hs[lvl])
            tmp = self.reg_branches[lvl](hs[lvl])

            # TODO: check the shape of reference
            assert reference.shape[-1] == 3
            # tmp: torch.Size([1, 900, 10])
            # tmp[..., 0:2] += reference[..., 0:2]
            # tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
            # tmp[..., 4:5] += reference[..., 2:3]
            # tmp[..., 4:5] = tmp[..., 4:5].sigmoid()

            # tmp[..., 0:1] = (
            #     tmp[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) +
            #     self.pc_range[0])
            # tmp[..., 1:2] = (
            #     tmp[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) +
            #     self.pc_range[1])
            # tmp[..., 4:5] = (
            #     tmp[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) +
            #     self.pc_range[2])

            # Avoid in-place updates on `tmp` (the commented block above):
            # in-place ops can yield incorrect metrics when the model is
            # optimized with Blade.
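            # cx, cy (indices 0:2) get the reference xy offset, are squashed
            # to [0, 1] with sigmoid, then rescaled to metric coordinates in
            # pc_range; cz (index 4) is treated the same way against the
            # reference z.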
            tmp_0_2 = tmp[..., 0:2]
            tmp_0_2_add_reference = tmp_0_2 + reference[..., 0:2]
            tmp_0_2_add_reference = tmp_0_2_add_reference.sigmoid()
            tmp_4_5 = tmp[..., 4:5]
            tmp_4_5_add_reference = tmp_4_5 + reference[..., 2:3]
            tmp_4_5_add_reference = tmp_4_5_add_reference.sigmoid()
            tmp_0_1 = tmp_0_2_add_reference[..., 0:1]
            tmp_0_1_new = (
                tmp_0_1 * (self.pc_range[3] - self.pc_range[0]) +
                self.pc_range[0])
            tmp_1_2 = tmp_0_2_add_reference[..., 1:2]
            tmp_1_2_new = (
                tmp_1_2 * (self.pc_range[4] - self.pc_range[1]) +
                self.pc_range[1])
            tmp_4_5_new = (
                tmp_4_5_add_reference * (self.pc_range[5] - self.pc_range[2]) +
                self.pc_range[2])

            tmp_2_4 = tmp[..., 2:4]
            tmp_5_10 = tmp[..., 5:10]
            tmp = torch.cat(
                [tmp_0_1_new, tmp_1_2_new, tmp_2_4, tmp_4_5_new, tmp_5_10],
                dim=-1)

            # TODO: check if using sigmoid
            outputs_coord = tmp
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)

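        # Main outputs keep only the one-to-one queries; auxiliary
        # one-to-many predictions (if any) are returned separately below
        # for hybrid-matching training.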
        outs = {
            'bev_embed': bev_embed,
            'all_cls_scores':
            outputs_classes[:, :, :self.num_query_one2one, :],
            'all_bbox_preds': outputs_coords[:, :, :self.num_query_one2one, :],
            'enc_cls_scores': None,
            'enc_bbox_preds': None,
        }

        if self.num_query_one2many > 0:
            outs['all_cls_scores_aux'] = outputs_classes[:, :,
                                                         self.num_query_one2one:, :]
            outs['all_bbox_preds_aux'] = outputs_coords[:, :,
                                                        self.num_query_one2one:, :]

        return outs
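
A minimal, self-contained sketch of the reference-point decoding performed in
the per-layer loop above, using dummy tensors. The `pc_range` values and the
shapes here are illustrative assumptions, and `inverse_sigmoid` is
re-implemented to mirror the helper imported by the real module:

    import torch

    def inverse_sigmoid(x, eps=1e-5):
        # Numerically stable logit; mirrors the mmdet-style helper.
        x = x.clamp(min=eps, max=1 - eps)
        return torch.log(x / (1 - x))

    pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]  # assumed nuScenes-style range
    reference = torch.rand(1, 900, 3)   # normalized (x, y, z) reference points
    tmp = torch.randn(1, 900, 10)       # raw regression output for one layer

    ref = inverse_sigmoid(reference)
    cxcy = (tmp[..., 0:2] + ref[..., 0:2]).sigmoid()
    cz = (tmp[..., 4:5] + ref[..., 2:3]).sigmoid()
    cx = cxcy[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    cy = cxcy[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    cz = cz * (pc_range[5] - pc_range[2]) + pc_range[2]
    box = torch.cat([cx, cy, tmp[..., 2:4], cz, tmp[..., 5:10]], dim=-1)
    assert box.shape == (1, 900, 10)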