in easycv/models/detection3d/detectors/bevformer/bevformer_head.py [0:0]
def forward(self, mlvl_feats, img_metas, prev_bev=None, only_bev=False):
"""Forward function.
Args:
mlvl_feats (tuple[Tensor]): Features from the upstream
network, each is a 5D-tensor with shape
(B, N, C, H, W).
prev_bev: previous bev featues
only_bev: only compute BEV features with encoder.
Returns:
all_cls_scores (Tensor): Outputs from the classification head, \
shape [nb_dec, bs, num_query, cls_out_channels]. Note \
cls_out_channels should includes background.
all_bbox_preds (Tensor): Sigmoid outputs from the regression \
head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \
Shape [nb_dec, bs, num_query, 9].
"""
bs, num_cam, _, _, _ = mlvl_feats[0].shape
dtype = mlvl_feats[0].dtype
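# Learnable object queries for the decoder and BEV queries for the encoder.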
object_query_embeds = self.query_embedding.weight.to(dtype)
bev_queries = self.bev_embedding.weight.to(dtype)
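# The all-zero mask is only used to generate the BEV positional encoding.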
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
device=bev_queries.device).to(dtype)
bev_pos = self.positional_encoding(bev_mask).to(dtype)
if only_bev: # only use encoder to obtain BEV features, TODO: refine the workaround
return self.transformer.get_bev_features(
mlvl_feats,
bev_queries,
self.bev_h,
self.bev_w,
grid_length=(self.real_h / self.bev_h,
self.real_w / self.bev_w),
bev_pos=bev_pos,
img_metas=img_metas,
prev_bev=prev_bev,
)
else:
# Build the self-attention mask for hybrid (one-to-one / one-to-many)
# matching: the two query groups must not attend to each other.
self_attn_mask = torch.zeros([self.num_query, self.num_query],
                             dtype=torch.bool,
                             device=bev_queries.device)
self_attn_mask[self.num_query_one2one:, :self.num_query_one2one] = True
self_attn_mask[:self.num_query_one2one, self.num_query_one2one:] = True
outputs = self.transformer(
mlvl_feats,
bev_queries,
object_query_embeds,
self.bev_h,
self.bev_w,
grid_length=(self.real_h / self.bev_h,
self.real_w / self.bev_w),
bev_pos=bev_pos,
reg_branches=self.reg_branches
if self.with_box_refine else None, # noqa:E501
cls_branches=self.cls_branches if self.as_two_stage else None,
img_metas=img_metas,
prev_bev=prev_bev,
attn_mask=self_attn_mask)
bev_embed, hs, init_reference, inter_references = outputs
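# hs: [num_layers, num_query, bs, embed_dims] -> [num_layers, bs, num_query, embed_dims]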
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
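# Decode the output of every decoder layer (for auxiliary deep supervision).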
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.cls_branches[lvl](hs[lvl])
tmp = self.reg_branches[lvl](hs[lvl])
# TODO: check the shape of reference
assert reference.shape[-1] == 3
# tmp: [bs, num_query, 10] raw regression output. Add the reference point to
# the (cx, cy) and cz channels, apply sigmoid, then rescale from the
# normalized [0, 1] range to the point-cloud range `self.pc_range`.
# In-place updates on `tmp` are avoided here because they can yield
# incorrect metrics when the model is exported with Blade.
tmp_0_2 = tmp[..., 0:2]
tmp_0_2_add_reference = tmp_0_2 + reference[..., 0:2]
tmp_0_2_add_reference = tmp_0_2_add_reference.sigmoid()
tmp_4_5 = tmp[..., 4:5]
tmp_4_5_add_reference = tmp_4_5 + reference[..., 2:3]
tmp_4_5_add_reference = tmp_4_5_add_reference.sigmoid()
tmp_0_1 = tmp_0_2_add_reference[..., 0:1]
tmp_0_1_new = (
tmp_0_1 * (self.pc_range[3] - self.pc_range[0]) +
self.pc_range[0])
tmp_1_2 = tmp_0_2_add_reference[..., 1:2]
tmp_1_2_new = (
tmp_1_2 * (self.pc_range[4] - self.pc_range[1]) +
self.pc_range[1])
tmp_4_5_new = (
tmp_4_5_add_reference * (self.pc_range[5] - self.pc_range[2]) +
self.pc_range[2])
tmp_2_4 = tmp[..., 2:4]
tmp_5_10 = tmp[..., 5:10]
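# Reassemble the full regression vector with the decoded centre coordinates.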
tmp = torch.cat(
[tmp_0_1_new, tmp_1_2_new, tmp_2_4, tmp_4_5_new, tmp_5_10],
dim=-1)
# TODO: check if using sigmoid
outputs_coord = tmp
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
outputs_classes = torch.stack(outputs_classes)
outputs_coords = torch.stack(outputs_coords)
# Only the one-to-one queries (the first num_query_one2one) yield the
# primary predictions; the remaining one-to-many queries are auxiliary.
outs = {
    'bev_embed': bev_embed,
    'all_cls_scores': outputs_classes[:, :, :self.num_query_one2one, :],
    'all_bbox_preds': outputs_coords[:, :, :self.num_query_one2one, :],
    'enc_cls_scores': None,
    'enc_bbox_preds': None,
}
if self.num_query_one2many > 0:
    outs['all_cls_scores_aux'] = outputs_classes[:, :, self.num_query_one2one:, :]
    outs['all_bbox_preds_aux'] = outputs_coords[:, :, self.num_query_one2one:, :]
return outs
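
# For reference, a minimal standalone sketch of the decode step above. It
# assumes `pc_range` is ordered (x_min, y_min, z_min, x_max, y_max, z_max),
# `tmp` is the raw 10-dim regression output and `reference` is already in
# logit (inverse-sigmoid) space. `decode_reference_points` is a hypothetical
# helper, not part of the original class.
import torch


def decode_reference_points(tmp, reference, pc_range):
    """Shift by the reference points, apply sigmoid and rescale to pc_range."""
    cx = (tmp[..., 0:1] + reference[..., 0:1]).sigmoid()
    cy = (tmp[..., 1:2] + reference[..., 1:2]).sigmoid()
    cz = (tmp[..., 4:5] + reference[..., 2:3]).sigmoid()
    cx = cx * (pc_range[3] - pc_range[0]) + pc_range[0]
    cy = cy * (pc_range[4] - pc_range[1]) + pc_range[1]
    cz = cz * (pc_range[5] - pc_range[2]) + pc_range[2]
    # The remaining channels (sizes, heading, velocity) are left unchanged.
    return torch.cat([cx, cy, tmp[..., 2:4], cz, tmp[..., 5:10]], dim=-1)


# Example with a single query whose reference point sits at the grid centre
# (zero in logit space). The pc_range values here are illustrative only.
tmp = torch.zeros(1, 1, 10)
reference = torch.zeros(1, 1, 3)
pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
print(decode_reference_points(tmp, reference, pc_range)[..., :2])  # ~[0., 0.]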