easycv/models/detection3d/detectors/bevformer/transformer.py (638 lines of code) (raw):

# Modified from https://github.com/fundamentalvision/BEVFormer. # Copyright (c) Alibaba, Inc. and its affiliates. import copy import warnings import numpy as np import torch import torch.nn as nn from mmcv import ConfigDict from mmcv.cnn import build_norm_layer, xavier_init from mmcv.runner import auto_fp16, force_fp32 from mmcv.runner.base_module import BaseModule, ModuleList from mmcv.utils import TORCH_VERSION, digit_version from torch.nn.init import normal_ from torchvision.transforms.functional import rotate from easycv.models.builder import (build_attention, build_feedforward_network, build_transformer_layer_sequence) from easycv.models.detection.utils.misc import inverse_sigmoid from easycv.models.registry import (POSITIONAL_ENCODING, TRANSFORMER, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE) from easycv.models.utils.transformer import (BaseTransformerLayer, TransformerLayerSequence) from . import (CustomMSDeformableAttention, MSDeformableAttention3D, TemporalSelfAttention) from .attentions.spatial_cross_attention import SpatialCrossAttention @torch.jit.script def _rotate(img: torch.Tensor, angle: torch.Tensor, center: torch.Tensor): """torch.jit.trace does not support torchvision.rotate""" img = rotate( img, float(angle.item()), center=[int(center[0].item()), int(center[1].item())]) return img @TRANSFORMER_LAYER.register_module() class DetrTransformerDecoderLayer(BaseTransformerLayer): """Implements decoder layer in DETR transformer. Args: attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): Configs for self_attention or cross_attention, the order should be consistent with it in `operation_order`. If it is a dict, it would be expand to the number of attention in `operation_order`. ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): Configs for FFN, The order of the configs in the list should be consistent with corresponding ffn in operation_order. If it is a dict, all of the attention modules in operation_order will be built with this config. operation_order (tuple[str]): The execution order of operation in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). Default:None norm_cfg (dict): Config dict for normalization layer. Default: `LN`. """ def __init__(self, attn_cfgs, ffn_cfgs, operation_order=None, norm_cfg=dict(type='LN'), **kwargs): super(DetrTransformerDecoderLayer, self).__init__( attn_cfgs=attn_cfgs, ffn_cfgs=ffn_cfgs, operation_order=operation_order, norm_cfg=norm_cfg, **kwargs) assert len(operation_order) == 6 assert set(operation_order) == set( ['self_attn', 'norm', 'cross_attn', 'ffn']) @TRANSFORMER_LAYER.register_module() class BEVFormerLayer(BaseModule): """Implements decoder layer in DETR transformer. Args: attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): Configs for `self_attention` or `cross_attention` modules, The order of the configs in the list should be consistent with corresponding attentions in operation_order. If it is a dict, all of the attention modules in operation_order will be built with this config. Default: None. operation_order (tuple[str]): The execution order of operation in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). Support `prenorm` when you specifying first element as `norm`. Default:None. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN'). ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): Configs for FFN, The order of the configs in the list should be consistent with corresponding ffn in operation_order. If it is a dict, all of the attention modules in operation_order will be built with this config. batch_first (bool): Key, Query and Value are shape of (batch, n, embed_dim) or (n, batch, embed_dim). Default to False. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. """ def __init__(self, attn_cfgs, operation_order=None, norm_cfg=dict(type='LN'), ffn_cfgs=dict( type='FFN', embed_dims=256, feedforward_channels=1024, num_fcs=2, ffn_drop=0., act_cfg=dict(type='ReLU', inplace=True), ), batch_first=True, init_cfg=None, adapt_jit=False, **kwargs): super(BEVFormerLayer, self).__init__(init_cfg) self.batch_first = batch_first assert set(operation_order) & set( ['self_attn', 'norm', 'ffn', 'cross_attn']) == \ set(operation_order), f'The operation_order of' \ f' {self.__class__.__name__} should ' \ f'contains all four operation type ' \ f"{['self_attn', 'norm', 'ffn', 'cross_attn']}" num_attn = operation_order.count('self_attn') + operation_order.count( 'cross_attn') if isinstance(attn_cfgs, dict): attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)] else: assert num_attn == len(attn_cfgs), f'The length ' \ f'of attn_cfg {num_attn} is ' \ f'not consistent with the number of attention' \ f'in operation_order {operation_order}.' self.num_attn = num_attn self.operation_order = operation_order self.norm_cfg = norm_cfg self.pre_norm = operation_order[0] == 'norm' self.attentions = ModuleList() index = 0 self.adapt_jit = adapt_jit for operation_name in operation_order: if operation_name in ['self_attn', 'cross_attn']: if 'batch_first' in attn_cfgs[index]: assert self.batch_first == attn_cfgs[index]['batch_first'] else: attn_cfgs[index]['batch_first'] = self.batch_first attention = build_attention(attn_cfgs[index]) # for export jit model if self.adapt_jit and isinstance(attention, SpatialCrossAttention): attention = torch.jit.script(attention) # Some custom attentions used as `self_attn` # or `cross_attn` can have different behavior. attention.operation_name = operation_name self.attentions.append(attention) index += 1 self.embed_dims = self.attentions[0].embed_dims self.ffns = ModuleList() num_ffns = operation_order.count('ffn') if isinstance(ffn_cfgs, dict): ffn_cfgs = ConfigDict(ffn_cfgs) if isinstance(ffn_cfgs, dict): ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] assert len(ffn_cfgs) == num_ffns for ffn_index in range(num_ffns): if 'embed_dims' not in ffn_cfgs[ffn_index]: ffn_cfgs['embed_dims'] = self.embed_dims else: assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims self.ffns.append(build_feedforward_network(ffn_cfgs[ffn_index])) self.norms = ModuleList() num_norms = operation_order.count('norm') for _ in range(num_norms): self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1]) assert len(operation_order) == 6 assert set(operation_order) == set( ['self_attn', 'norm', 'cross_attn', 'ffn']) def forward(self, query, key=None, value=None, bev_pos=None, query_pos=None, key_pos=None, attn_masks=None, query_key_padding_mask=None, key_padding_mask=None, ref_2d=None, ref_3d=None, bev_h=None, bev_w=None, reference_points_cam=None, mask=None, spatial_shapes=None, level_start_index=None, prev_bev=None, **kwargs): """Forward function for `TransformerDecoderLayer`. **kwargs contains some specific arguments of attentions. Args: query (Tensor): The input query with shape [num_queries, bs, embed_dims] if self.batch_first is False, else [bs, num_queries embed_dims]. key (Tensor): The key tensor with shape [num_keys, bs, embed_dims] if self.batch_first is False, else [bs, num_keys, embed_dims] . value (Tensor): The value tensor with same shape as `key`. query_pos (Tensor): The positional encoding for `query`. Default: None. key_pos (Tensor): The positional encoding for `key`. Default: None. attn_masks (List[Tensor] | None): 2D Tensor used in calculation of corresponding attention. The length of it should equal to the number of `attention` in `operation_order`. Default: None. query_key_padding_mask (Tensor): ByteTensor for `query`, with shape [bs, num_queries]. Only used in `self_attn` layer. Defaults to None. key_padding_mask (Tensor): ByteTensor for `query`, with shape [bs, num_keys]. Default: None. Returns: Tensor: forwarded results with shape [num_queries, bs, embed_dims]. """ norm_index = 0 attn_index = 0 ffn_index = 0 identity = query if attn_masks is None: attn_masks = [None for _ in range(self.num_attn)] elif isinstance(attn_masks, torch.Tensor): attn_masks = [ copy.deepcopy(attn_masks) for _ in range(self.num_attn) ] warnings.warn(f'Use same attn_mask in all attentions in ' f'{self.__class__.__name__} ') else: assert len(attn_masks) == self.num_attn, f'The length of ' \ f'attn_masks {len(attn_masks)} must be equal ' \ f'to the number of attention in ' \ f'operation_order {self.num_attn}' for layer in self.operation_order: # temporal self attention if layer == 'self_attn': query = self.attentions[attn_index]( query=query, key=prev_bev, value=prev_bev, identity=identity if self.pre_norm else None, query_pos=bev_pos, key_padding_mask=query_key_padding_mask, reference_points=ref_2d, spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device), level_start_index=torch.tensor([0], device=query.device), ) attn_index += 1 identity = query elif layer == 'norm': # fix fp16 dtype = query.dtype query = self.norms[norm_index](query) query = query.to(dtype) norm_index += 1 # spaital cross attention elif layer == 'cross_attn': query = self.attentions[attn_index]( query=query, key=key, value=value, residual=identity if self.pre_norm else None, query_pos=query_pos, reference_points=ref_3d, reference_points_cam=reference_points_cam, bev_mask=kwargs.get('bev_mask'), key_padding_mask=key_padding_mask, spatial_shapes=spatial_shapes, level_start_index=level_start_index, ) attn_index += 1 identity = query elif layer == 'ffn': query = self.ffns[ffn_index]( query, identity if self.pre_norm else None) ffn_index += 1 return query @TRANSFORMER_LAYER_SEQUENCE.register_module() class Detr3DTransformerDecoder(TransformerLayerSequence): """Implements the decoder in DETR3D transformer. Args: return_intermediate (bool): Whether to return intermediate outputs. coder_norm_cfg (dict): Config of last normalization layer. Default: `LN`. """ def __init__(self, *args, return_intermediate=False, **kwargs): super(Detr3DTransformerDecoder, self).__init__(*args, **kwargs) self.return_intermediate = return_intermediate def forward(self, query, *args, reference_points=None, reg_branches=None, key_padding_mask=None, attn_mask=None, **kwargs): """Forward function for `Detr3DTransformerDecoder`. Args: query (Tensor): Input query with shape `(num_query, bs, embed_dims)`. reference_points (Tensor): The reference points of offset. has shape (bs, num_query, 4) when as_two_stage, otherwise has shape ((bs, num_query, 2). reg_branch: (obj:`nn.ModuleList`): Used for refining the regression results. Only would be passed when with_box_refine is True, otherwise would be passed a `None`. Returns: Tensor: Results with shape [1, num_query, bs, embed_dims] when return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims]. """ output = query intermediate = [] intermediate_reference_points = [] for lid, layer in enumerate(self.layers): reference_points_input = reference_points[..., :2].unsqueeze( 2) # BS NUM_QUERY NUM_LEVEL 2 output = layer( output, *args, reference_points=reference_points_input, attn_masks=[attn_mask] * layer.num_attn, key_padding_mask=key_padding_mask, **kwargs) output = output.permute(1, 0, 2) if reg_branches is not None: tmp = reg_branches[lid](output) assert reference_points.shape[-1] == 3 # new_reference_points = torch.zeros_like( # reference_points) # torch.Size([1, 900, 3]) # new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid( # reference_points[..., :2], eps=1e-5) # new_reference_points[..., # 2:3] = tmp[..., 4:5] + inverse_sigmoid( # reference_points[..., 2:3], eps=1e-5) # new_reference_points = new_reference_points.sigmoid() # reference_points = new_reference_points.detach() # remove inplace operation, metric may incorrect when using blade new_reference_points_0_2 = tmp[..., :2] + inverse_sigmoid( reference_points[..., :2], eps=1e-5) new_reference_points_2_3 = tmp[..., 4:5] + inverse_sigmoid( reference_points[..., 2:3], eps=1e-5) new_reference_points = torch.cat( [new_reference_points_0_2, new_reference_points_2_3], dim=-1) new_reference_points = new_reference_points.sigmoid() reference_points = new_reference_points.detach() output = output.permute(1, 0, 2) if self.return_intermediate: intermediate.append(output) intermediate_reference_points.append(reference_points) if self.return_intermediate: return torch.stack(intermediate), torch.stack( intermediate_reference_points) return output, reference_points @TRANSFORMER_LAYER_SEQUENCE.register_module() class BEVFormerEncoder(TransformerLayerSequence): """ Attention with both self and cross Implements the decoder in DETR transformer. Args: return_intermediate (bool): Whether to return intermediate outputs. coder_norm_cfg (dict): Config of last normalization layer. Default: `LN`. """ def __init__(self, *args, pc_range=None, num_points_in_pillar=4, return_intermediate=False, dataset_type='nuscenes', **kwargs): super(BEVFormerEncoder, self).__init__(*args, **kwargs) self.return_intermediate = return_intermediate self.num_points_in_pillar = num_points_in_pillar self.pc_range = pc_range @staticmethod def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float): """Get the reference points used in SCA and TSA. Args: H, W: spatial shape of bev. Z: hight of pillar. D: sample D points uniformly from each pillar. device (obj:`device`): The device where reference_points should be. Returns: Tensor: reference points used in decoder, has \ shape (bs, num_keys, num_levels, 2). """ # reference points in 3D space, used in spatial cross-attention (SCA) if dim == '3d': zs = torch.linspace( 0.5, Z - 0.5, num_points_in_pillar, dtype=dtype, device=device).view(-1, 1, 1).expand( num_points_in_pillar, H, W) / Z xs = torch.linspace( 0.5, W - 0.5, W, dtype=dtype, device=device).view( 1, 1, W).expand(num_points_in_pillar, H, W) / W ys = torch.linspace( 0.5, H - 0.5, H, dtype=dtype, device=device).view( 1, H, 1).expand(num_points_in_pillar, H, W) / H ref_3d = torch.stack((xs, ys, zs), -1) ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1) ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) return ref_3d # reference points on 2D bev plane, used in temporal self-attention (TSA). elif dim == '2d': ref_y, ref_x = torch.meshgrid( torch.linspace(0.5, H - 0.5, H, dtype=dtype, device=device), torch.linspace(0.5, W - 0.5, W, dtype=dtype, device=device)) ref_y = ref_y.reshape(-1)[None] / H ref_x = ref_x.reshape(-1)[None] / W ref_2d = torch.stack((ref_x, ref_y), -1) ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2) return ref_2d # This function must use fp32!!! @force_fp32(apply_to=('reference_points', 'img_metas')) def point_sampling(self, reference_points, pc_range, img_metas): lidar2img = torch.stack([meta['lidar2img'] for meta in img_metas ]).to(reference_points.dtype) # (B, N, 4, 4) reference_points = reference_points.clone() reference_points[..., 0:1] = reference_points[..., 0:1] * \ (pc_range[3] - pc_range[0]) + pc_range[0] reference_points[..., 1:2] = reference_points[..., 1:2] * \ (pc_range[4] - pc_range[1]) + pc_range[1] reference_points[..., 2:3] = reference_points[..., 2:3] * \ (pc_range[5] - pc_range[2]) + pc_range[2] reference_points = torch.cat( (reference_points, torch.ones_like(reference_points[..., :1])), -1) reference_points = reference_points.permute(1, 0, 2, 3) D, B, num_query = reference_points.size()[:3] num_cam = lidar2img.size(1) reference_points = reference_points.view(D, B, 1, num_query, 4).repeat( 1, 1, num_cam, 1, 1).unsqueeze(-1) lidar2img = lidar2img.view(1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1) reference_points_cam = torch.matmul( lidar2img.to(torch.float32), reference_points.to(torch.float32)).squeeze(-1) eps = 1e-5 bev_mask = (reference_points_cam[..., 2:3] > eps) reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum( reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps) reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1] reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0] bev_mask = ( bev_mask & (reference_points_cam[..., 1:2] > 0.0) & (reference_points_cam[..., 1:2] < 1.0) & (reference_points_cam[..., 0:1] < 1.0) & (reference_points_cam[..., 0:1] > 0.0)) if digit_version(TORCH_VERSION) >= digit_version('1.8'): bev_mask = torch.nan_to_num(bev_mask) else: bev_mask = bev_mask.new_tensor( np.nan_to_num(bev_mask.cpu().numpy())) reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4) bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1) return reference_points_cam, bev_mask @auto_fp16() def forward(self, bev_query, key, value, *args, bev_h=None, bev_w=None, bev_pos=None, spatial_shapes=None, level_start_index=None, valid_ratios=None, prev_bev=None, shift=0., **kwargs): """Forward function for `TransformerDecoder`. Args: bev_query (Tensor): Input BEV query with shape `(num_query, bs, embed_dims)`. key & value (Tensor): Input multi-cameta features with shape (num_cam, num_value, bs, embed_dims) reference_points (Tensor): The reference points of offset. has shape (bs, num_query, 4) when as_two_stage, otherwise has shape ((bs, num_query, 2). valid_ratios (Tensor): The radios of valid points on the feature map, has shape (bs, num_levels, 2) Returns: Tensor: Results with shape [1, num_query, bs, embed_dims] when return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims]. """ output = bev_query intermediate = [] ref_3d = self.get_reference_points( bev_h, bev_w, self.pc_range[5] - self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype) ref_2d = self.get_reference_points( bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype) reference_points_cam, bev_mask = self.point_sampling( ref_3d, self.pc_range, kwargs['img_metas']) # bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper. shift_ref_2d = ref_2d # .clone() shift_ref_2d += shift[:, None, None, :] # (num_query, bs, embed_dims) -> (bs, num_query, embed_dims) bev_query = bev_query.permute(1, 0, 2) bev_pos = bev_pos.permute(1, 0, 2) bs, len_bev, num_bev_level, _ = ref_2d.shape if prev_bev is not None: prev_bev = prev_bev.permute(1, 0, 2) prev_bev = torch.stack([prev_bev, bev_query], 1).reshape(bs * 2, len_bev, -1) hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(bs * 2, len_bev, num_bev_level, 2) else: hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(bs * 2, len_bev, num_bev_level, 2) for lid, layer in enumerate(self.layers): output = layer( bev_query, key, value, *args, bev_pos=bev_pos, ref_2d=hybird_ref_2d, ref_3d=ref_3d, bev_h=bev_h, bev_w=bev_w, spatial_shapes=spatial_shapes, level_start_index=level_start_index, reference_points_cam=reference_points_cam, bev_mask=bev_mask, prev_bev=prev_bev, **kwargs) bev_query = output if self.return_intermediate: intermediate.append(output) if self.return_intermediate: return torch.stack(intermediate) return output @TRANSFORMER.register_module() class PerceptionTransformer(BaseModule): """Implements the Detr3D transformer. Args: as_two_stage (bool): Generate query from encoder features. Default: False. num_feature_levels (int): Number of feature maps from FPN: Default: 4. two_stage_num_proposals (int): Number of proposals when set `as_two_stage` as True. Default: 300. """ def __init__(self, num_feature_levels=4, num_cams=6, two_stage_num_proposals=300, encoder=None, decoder=None, embed_dims=256, rotate_prev_bev=True, use_shift=True, use_can_bus=True, can_bus_norm=True, use_cams_embeds=True, rotate_center=[100, 100], **kwargs): super(PerceptionTransformer, self).__init__(**kwargs) self.encoder = build_transformer_layer_sequence(encoder) self.decoder = build_transformer_layer_sequence(decoder) self.embed_dims = embed_dims self.num_feature_levels = num_feature_levels self.num_cams = num_cams self.rotate_prev_bev = rotate_prev_bev self.use_shift = use_shift self.use_can_bus = use_can_bus self.can_bus_norm = can_bus_norm self.use_cams_embeds = use_cams_embeds self.two_stage_num_proposals = two_stage_num_proposals self.init_layers() self.rotate_center = rotate_center def init_layers(self): """Initialize layers of the Detr3DTransformer.""" self.level_embeds = nn.Parameter( torch.Tensor(self.num_feature_levels, self.embed_dims)) self.cams_embeds = nn.Parameter( torch.Tensor(self.num_cams, self.embed_dims)) self.reference_points = nn.Linear(self.embed_dims, 3) self.can_bus_mlp = nn.Sequential( nn.Linear(18, self.embed_dims // 2), nn.ReLU(inplace=True), nn.Linear(self.embed_dims // 2, self.embed_dims), nn.ReLU(inplace=True), ) if self.can_bus_norm: self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims)) def init_weights(self): """Initialize the transformer weights.""" for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \ or isinstance(m, CustomMSDeformableAttention): m.init_weights() normal_(self.level_embeds) normal_(self.cams_embeds) xavier_init(self.reference_points, distribution='uniform', bias=0.) xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.) @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos')) def get_bev_features(self, mlvl_feats, bev_queries, bev_h, bev_w, grid_length=[0.512, 0.512], bev_pos=None, prev_bev=None, **kwargs): """ obtain bev features. """ bs = mlvl_feats[0].size(0) bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1) bev_pos = bev_pos.flatten(2).permute(2, 0, 1) # obtain rotation angle and shift with ego motion delta_x = torch.stack( [each['can_bus'][0] for each in kwargs['img_metas']]) delta_y = torch.stack( [each['can_bus'][1] for each in kwargs['img_metas']]) ego_angle = torch.stack([ each['can_bus'][-2] / np.pi * 180 for each in kwargs['img_metas'] ]) grid_length_y = grid_length[0] grid_length_x = grid_length[1] translation_length = torch.sqrt(delta_x**2 + delta_y**2) translation_angle = torch.atan2(delta_y, delta_x) / np.pi * 180 bev_angle = ego_angle - translation_angle shift_y = translation_length * \ torch.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h shift_x = translation_length * \ torch.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w if not self.use_shift: shift_y = shift_y.new_zeros(shift_y.size()) shift_x = shift_x.new_zeros(shift_y.size()) shift = torch.stack([shift_x, shift_y]).permute(1, 0).to(bev_queries.dtype) if prev_bev is not None: if prev_bev.shape[1] == bev_h * bev_w: prev_bev = prev_bev.permute(1, 0, 2) if self.rotate_prev_bev: for i in range(bs): # num_prev_bev = prev_bev.size(1) rotation_angle = kwargs['img_metas'][i]['can_bus'][-1] tmp_prev_bev = prev_bev[:, i].reshape(bev_h, bev_w, -1).permute(2, 0, 1) tmp_prev_bev = _rotate( tmp_prev_bev, rotation_angle, center=torch.tensor(self.rotate_center)) tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape( bev_h * bev_w, 1, -1) prev_bev[:, i] = tmp_prev_bev[:, 0] # add can bus signals can_bus = torch.stack([ each['can_bus'] for each in kwargs['img_metas'] ]).to(bev_queries.dtype) can_bus = self.can_bus_mlp(can_bus)[None, :, :] # fix fp16 can_bus = can_bus.to(bev_queries.dtype) if self.use_can_bus: bev_queries = bev_queries + can_bus feat_flatten = [] spatial_shapes = [] for lvl, feat in enumerate(mlvl_feats): bs, num_cam, c, h, w = feat.shape spatial_shape = (h, w) feat = feat.flatten(3).permute(1, 0, 3, 2) if self.use_cams_embeds: feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype) feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to( feat.dtype) spatial_shapes.append(spatial_shape) feat_flatten.append(feat) feat_flatten = torch.cat(feat_flatten, 2) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=bev_pos.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) feat_flatten = feat_flatten.permute( 0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims) bev_embed = self.encoder( bev_queries, feat_flatten, feat_flatten, bev_h=bev_h, bev_w=bev_w, bev_pos=bev_pos, spatial_shapes=spatial_shapes, level_start_index=level_start_index, prev_bev=prev_bev, shift=shift, **kwargs) return bev_embed @auto_fp16( apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed', 'prev_bev', 'bev_pos')) def forward(self, mlvl_feats, bev_queries, object_query_embed, bev_h, bev_w, grid_length=[0.512, 0.512], bev_pos=None, reg_branches=None, cls_branches=None, prev_bev=None, attn_mask=None, **kwargs): """Forward function for `Detr3DTransformer`. Args: mlvl_feats (list(Tensor)): Input queries from different level. Each element has shape [bs, num_cams, embed_dims, h, w]. bev_queries (Tensor): (bev_h*bev_w, c) bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w) object_query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. reg_branches (obj:`nn.ModuleList`): Regression heads for feature maps from each decoder layer. Only would be passed when `with_box_refine` is True. Default to None. Returns: tuple[Tensor]: results of decoder containing the following tensor. - bev_embed: BEV features - inter_states: Outputs from decoder. If return_intermediate_dec is True output has shape \ (num_dec_layers, bs, num_query, embed_dims), else has \ shape (1, bs, num_query, embed_dims). - init_reference_out: The initial value of reference \ points, has shape (bs, num_queries, 4). - inter_references_out: The internal value of reference \ points in decoder, has shape \ (num_dec_layers, bs,num_query, embed_dims) - enc_outputs_class: The classification score of \ proposals generated from \ encoder's feature maps, has shape \ (batch, h*w, num_classes). \ Only would be returned when `as_two_stage` is True, \ otherwise None. - enc_outputs_coord_unact: The regression results \ generated from encoder's feature maps., has shape \ (batch, h*w, 4). Only would \ be returned when `as_two_stage` is True, \ otherwise None. """ bev_embed = self.get_bev_features( mlvl_feats, bev_queries, bev_h, bev_w, grid_length=grid_length, bev_pos=bev_pos, prev_bev=prev_bev, **kwargs) # bev_embed shape: bs, bev_h*bev_w, embed_dims bs = mlvl_feats[0].size(0) query_pos, query = torch.split( object_query_embed, self.embed_dims, dim=1) query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) query = query.unsqueeze(0).expand(bs, -1, -1) reference_points = self.reference_points(query_pos) reference_points = reference_points.sigmoid() init_reference_out = reference_points query = query.permute(1, 0, 2) query_pos = query_pos.permute(1, 0, 2) bev_embed = bev_embed.permute(1, 0, 2) inter_states, inter_references = self.decoder( query=query, key=None, value=bev_embed, query_pos=query_pos, reference_points=reference_points, attn_mask=attn_mask, reg_branches=reg_branches, cls_branches=cls_branches, spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device), level_start_index=torch.tensor([0], device=query.device), **kwargs) inter_references_out = inter_references return bev_embed, inter_states, init_reference_out, inter_references_out @POSITIONAL_ENCODING.register_module() class LearnedPositionalEncoding(BaseModule): """Position embedding with learnable embedding weights. Args: num_feats (int): The feature dimension for each position along x-axis or y-axis. The final returned dimension for each position is 2 times of this value. row_num_embed (int, optional): The dictionary size of row embeddings. Default 50. col_num_embed (int, optional): The dictionary size of col embeddings. Default 50. init_cfg (dict or list[dict], optional): Initialization config dict. """ def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type='Uniform', layer='Embedding')): super(LearnedPositionalEncoding, self).__init__(init_cfg) self.row_embed = nn.Embedding(row_num_embed, num_feats) self.col_embed = nn.Embedding(col_num_embed, num_feats) self.num_feats = num_feats self.row_num_embed = row_num_embed self.col_num_embed = col_num_embed def forward(self, mask): """Forward function for `LearnedPositionalEncoding`. Args: mask (Tensor): ByteTensor mask. Non-zero values representing ignored positions, while zero values means valid positions for this image. Shape [bs, h, w]. Returns: pos (Tensor): Returned position embedding with shape [bs, num_feats*2, h, w]. """ h, w = mask.shape[-2:] x = torch.arange(w, device=mask.device) y = torch.arange(h, device=mask.device) x_embed = self.col_embed(x) y_embed = self.row_embed(y) pos = torch.cat( (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat( 1, w, 1)), dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) return pos def __repr__(self): """str: a string that describes the module""" repr_str = self.__class__.__name__ repr_str += f'(num_feats={self.num_feats}, ' repr_str += f'row_num_embed={self.row_num_embed}, ' repr_str += f'col_num_embed={self.col_num_embed})' return repr_str