easycv/models/detection3d/detectors/bevformer/transformer.py (638 lines of code) (raw):
# Modified from https://github.com/fundamentalvision/BEVFormer.
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import warnings
import numpy as np
import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.cnn import build_norm_layer, xavier_init
from mmcv.runner import auto_fp16, force_fp32
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils import TORCH_VERSION, digit_version
from torch.nn.init import normal_
from torchvision.transforms.functional import rotate
from easycv.models.builder import (build_attention, build_feedforward_network,
build_transformer_layer_sequence)
from easycv.models.detection.utils.misc import inverse_sigmoid
from easycv.models.registry import (POSITIONAL_ENCODING, TRANSFORMER,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from easycv.models.utils.transformer import (BaseTransformerLayer,
TransformerLayerSequence)
from . import (CustomMSDeformableAttention, MSDeformableAttention3D,
TemporalSelfAttention)
from .attentions.spatial_cross_attention import SpatialCrossAttention
@torch.jit.script
def _rotate(img: torch.Tensor, angle: torch.Tensor, center: torch.Tensor):
    """Rotate ``img`` with torchvision's ``rotate`` around ``center``.

    torch.jit.trace does not support torchvision.rotate, so this wrapper is
    compiled with ``torch.jit.script`` instead.

    Args:
        img (Tensor): Tensor handed unchanged to ``rotate``.
        angle (Tensor): Single-element tensor; its value is passed to
            ``rotate`` as the rotation angle (degrees per torchvision docs).
        center (Tensor): Its first two entries are truncated to ``int`` and
            passed to ``rotate`` as the rotation center.

    Returns:
        Tensor: The tensor returned by ``rotate``.
    """
    angle_value = float(angle.item())
    center_a = int(center[0].item())
    center_b = int(center[1].item())
    rotated = rotate(img, angle_value, center=[center_a, center_b])
    return rotated
@TRANSFORMER_LAYER.register_module()
class DetrTransformerDecoderLayer(BaseTransformerLayer):
    """Decoder layer of the DETR transformer.

    A thin wrapper over ``BaseTransformerLayer``. After the base class is
    built, it checks that ``operation_order`` has exactly six entries and
    uses exactly the four kinds ``self_attn``, ``norm``, ``cross_attn`` and
    ``ffn``.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
            Configs for self_attention or cross_attention, in the same order
            as they appear in `operation_order`. A single dict is expanded
            to the number of attentions in `operation_order`.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for the FFNs, in the same order as the ffn entries in
            `operation_order`. A single dict is used for every FFN.
        operation_order (tuple[str]): Execution order of the operations,
            such as ('self_attn', 'norm', 'ffn', 'norm'). Default: None.
        norm_cfg (dict): Config dict for the normalization layers.
            Default: `LN`.
    """

    def __init__(self,
                 attn_cfgs,
                 ffn_cfgs,
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 **kwargs):
        super().__init__(
            attn_cfgs=attn_cfgs,
            ffn_cfgs=ffn_cfgs,
            operation_order=operation_order,
            norm_cfg=norm_cfg,
            **kwargs)
        # Exactly six operations, drawn from (and covering) these four kinds.
        expected_kinds = {'self_attn', 'norm', 'cross_attn', 'ffn'}
        assert len(operation_order) == 6
        assert set(operation_order) == expected_kinds
@TRANSFORMER_LAYER.register_module()
class BEVFormerLayer(BaseModule):
    """Encoder layer of BEVFormer.

    Runs the operations listed in ``operation_order`` in sequence:

    - ``self_attn``: temporal self-attention of the BEV queries over
      ``prev_bev``.
    - ``cross_attn``: spatial cross-attention over the camera features.
    - ``norm`` / ``ffn``: the corresponding normalization / FFN sub-modules.

    ``operation_order`` must contain exactly six operations covering all
    four kinds.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
            Configs for the ``self_attn`` / ``cross_attn`` modules, in the
            order they appear in ``operation_order``. A single dict is
            replicated for every attention. The configs are copied, so the
            caller's objects are not modified.
        operation_order (tuple[str]): Execution order of the operations,
            e.g. ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm').
            Pre-norm mode is enabled when the first element is ``norm``.
            Default: None.
        norm_cfg (dict): Config dict for the normalization layers.
            Default: dict(type='LN').
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
            Configs for the FFNs, in the order they appear in
            ``operation_order``. A single dict is replicated for every FFN.
            A missing ``embed_dims`` is filled in from the first attention's
            ``embed_dims``; an explicit one must match it.
        batch_first (bool): Whether query/key/value are
            (batch, n, embed_dims) rather than (n, batch, embed_dims). It is
            written into every attention config that does not set it.
            Default: True.
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
        adapt_jit (bool): If True, every built ``SpatialCrossAttention`` is
            compiled with ``torch.jit.script`` (for exporting a jit model).
            Default: False.
    """

    def __init__(self,
                 attn_cfgs,
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 batch_first=True,
                 init_cfg=None,
                 adapt_jit=False,
                 **kwargs):
        super(BEVFormerLayer, self).__init__(init_cfg)
        self.batch_first = batch_first
        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {len(attn_cfgs)} is ' \
                f'not consistent with the number of attention ' \
                f'{num_attn} in operation_order {operation_order}.'
            # Copy so that filling in `batch_first` below does not mutate
            # the caller's config dicts.
            attn_cfgs = [copy.deepcopy(cfg) for cfg in attn_cfgs]
        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()
        self.adapt_jit = adapt_jit

        index = 0
        for operation_name in operation_order:
            if operation_name not in ['self_attn', 'cross_attn']:
                continue
            attn_cfg = attn_cfgs[index]
            if 'batch_first' in attn_cfg:
                assert self.batch_first == attn_cfg['batch_first']
            else:
                attn_cfg['batch_first'] = self.batch_first
            attention = build_attention(attn_cfg)
            # for export jit model
            if self.adapt_jit and isinstance(attention, SpatialCrossAttention):
                attention = torch.jit.script(attention)
            # Some custom attentions used as `self_attn`
            # or `cross_attn` can have different behavior.
            attention.operation_name = operation_name
            self.attentions.append(attention)
            index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            # Replicate a single config for every FFN in operation_order.
            ffn_cfgs = ConfigDict(ffn_cfgs)
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        else:
            # Copy so that filling in `embed_dims` below does not mutate
            # the caller's config dicts.
            ffn_cfgs = [copy.deepcopy(cfg) for cfg in ffn_cfgs]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            ffn_cfg = ffn_cfgs[ffn_index]
            if 'embed_dims' not in ffn_cfg:
                # Fill the per-FFN config, not the list holding the configs.
                ffn_cfg['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfg['embed_dims'] == self.embed_dims
            self.ffns.append(build_feedforward_network(ffn_cfg))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])

    def forward(self,
                query,
                key=None,
                value=None,
                bev_pos=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                ref_2d=None,
                ref_3d=None,
                bev_h=None,
                bev_w=None,
                reference_points_cam=None,
                mask=None,
                spatial_shapes=None,
                level_start_index=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `BEVFormerLayer`.

        Walks ``operation_order`` and applies each operation to ``query``.

        Args:
            query (Tensor): The BEV queries, [bs, num_queries, embed_dims]
                when ``batch_first`` is True, else
                [num_queries, bs, embed_dims].
            key (Tensor): Key of the spatial cross-attention.
            value (Tensor): Value of the spatial cross-attention.
            bev_pos (Tensor): Positional encoding passed as ``query_pos``
                to the temporal self-attention.
            query_pos (Tensor): Positional encoding passed as ``query_pos``
                to the spatial cross-attention. Default: None.
            key_pos (Tensor): Unused. Default: None.
            attn_masks (List[Tensor] | Tensor | None): Validated/expanded to
                one entry per attention, but not forwarded to any attention
                module. Default: None.
            query_key_padding_mask (Tensor): Passed as ``key_padding_mask``
                to the temporal self-attention. Default: None.
            key_padding_mask (Tensor): Passed as ``key_padding_mask`` to the
                spatial cross-attention. Default: None.
            ref_2d (Tensor): Reference points of the temporal
                self-attention.
            ref_3d (Tensor): Reference points of the spatial
                cross-attention.
            bev_h, bev_w (int): BEV grid size; used to build the single-level
                ``spatial_shapes`` of the temporal self-attention.
            reference_points_cam (Tensor): Reference points projected to the
                cameras, passed to the spatial cross-attention.
            mask: Unused.
            spatial_shapes, level_start_index (Tensor): Describe the camera
                feature levels for the spatial cross-attention.
            prev_bev (Tensor): Key/value of the temporal self-attention.
            **kwargs: ``bev_mask`` (if present) is passed to the spatial
                cross-attention; other keys are ignored.

        Returns:
            Tensor: The updated query, in the layout produced by the
            attention/FFN modules (presumably the same as ``query``).
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            # temporal self attention
            if layer == 'self_attn':
                query = self.attentions[attn_index](
                    query=query,
                    key=prev_bev,
                    value=prev_bev,
                    identity=identity if self.pre_norm else None,
                    query_pos=bev_pos,
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=torch.tensor([[bev_h, bev_w]],
                                                device=query.device),
                    level_start_index=torch.tensor([0], device=query.device),
                )
                attn_index += 1
                identity = query

            elif layer == 'norm':
                # fix fp16: keep the query dtype across the norm layer
                dtype = query.dtype
                query = self.norms[norm_index](query)
                query = query.to(dtype)
                norm_index += 1

            # spatial cross attention
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query=query,
                    key=key,
                    value=value,
                    residual=identity if self.pre_norm else None,
                    query_pos=query_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    bev_mask=kwargs.get('bev_mask'),
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                )
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class Detr3DTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR3D transformer.

    Args:
        return_intermediate (bool): Whether to return the outputs and
            reference points of every layer (stacked) instead of only the
            last ones. Default: False.
        *args, **kwargs: Forwarded to ``TransformerLayerSequence``.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(Detr3DTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate

    def forward(self,
                query,
                *args,
                reference_points=None,
                reg_branches=None,
                key_padding_mask=None,
                attn_mask=None,
                **kwargs):
        """Forward function for `Detr3DTransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): Sigmoid-space reference points; only
                the first two channels are fed to the layers. When
                ``reg_branches`` is given the last dim must be 3.
            reg_branches (obj:`nn.ModuleList` | None): Per-layer regression
                heads used to refine the reference points. Only passed when
                with_box_refine is True, otherwise `None`.
            key_padding_mask (Tensor): Forwarded to every layer.
            attn_mask (Tensor): Replicated for every attention of a layer.

        Returns:
            tuple[Tensor, Tensor]: ``(output, reference_points)``. When
            ``return_intermediate`` is False these are the last layer's
            output and the final reference points; otherwise the per-layer
            values are stacked along a new dim 0.
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # BS NUM_QUERY NUM_LEVEL 2
            reference_points_input = reference_points[..., :2].unsqueeze(2)
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                attn_masks=[attn_mask] * layer.num_attn,
                key_padding_mask=key_padding_mask,
                **kwargs)
            # Swap the first two dims so the regression head sees the
            # other layout; swapped back below.
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                assert reference_points.shape[-1] == 3
                # Refine in inverse-sigmoid space: channels 0:2 of the
                # regression output update x/y and channel 4 updates z.
                # Built with torch.cat instead of in-place slice assignment:
                # the in-place version may give incorrect metrics with blade.
                new_reference_points_0_2 = tmp[..., :2] + inverse_sigmoid(
                    reference_points[..., :2], eps=1e-5)
                new_reference_points_2_3 = tmp[..., 4:5] + inverse_sigmoid(
                    reference_points[..., 2:3], eps=1e-5)
                new_reference_points = torch.cat(
                    [new_reference_points_0_2, new_reference_points_2_3],
                    dim=-1)
                new_reference_points = new_reference_points.sigmoid()
                # Detached: no gradient flows through the reference points
                # into the previous layers.
                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BEVFormerEncoder(TransformerLayerSequence):
    """BEVFormer encoder: a stack of layers combining temporal
    self-attention and spatial cross-attention.

    Args:
        pc_range (list[float]): Point-cloud range
            [x_min, y_min, z_min, x_max, y_max, z_max] used to map the
            normalized 3D reference points to real coordinates.
        num_points_in_pillar (int): Number of points sampled along the
            height of each BEV pillar. Default: 4.
        return_intermediate (bool): Whether to return the outputs of all
            layers (stacked) instead of only the last one. Default: False.
        dataset_type (str): Accepted but unused here. Default: 'nuscenes'.
        *args, **kwargs: Forwarded to ``TransformerLayerSequence``.
    """

    def __init__(self,
                 *args,
                 pc_range=None,
                 num_points_in_pillar=4,
                 return_intermediate=False,
                 dataset_type='nuscenes',
                 **kwargs):
        super(BEVFormerEncoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.num_points_in_pillar = num_points_in_pillar
        self.pc_range = pc_range

    @staticmethod
    def get_reference_points(H,
                             W,
                             Z=8,
                             num_points_in_pillar=4,
                             dim='3d',
                             bs=1,
                             device='cuda',
                             dtype=torch.float):
        """Get the reference points used in SCA and TSA.

        All coordinates are cell centers normalized to [0, 1] by dividing by
        W, H and Z respectively.

        Args:
            H, W (int): Spatial shape of the BEV grid.
            Z (float): Height of the pillar (only used for normalization).
            num_points_in_pillar (int): Number of points sampled uniformly
                along each pillar (3D only).
            dim (str): '3d' for spatial cross-attention (SCA) points,
                '2d' for temporal self-attention (TSA) points.
            bs (int): Batch size to repeat the points for.
            device (obj:`device`): Device of the returned tensor.
            dtype (obj:`dtype`): Dtype of the returned tensor.

        Returns:
            Tensor: For '3d', shape (bs, num_points_in_pillar, H*W, 3) with
            (x, y, z); for '2d', shape (bs, H*W, 1, 2) with (x, y).

        Raises:
            ValueError: If ``dim`` is neither '3d' nor '2d'.
        """
        # reference points in 3D space, used in spatial cross-attention (SCA)
        if dim == '3d':
            zs = torch.linspace(
                0.5, Z - 0.5, num_points_in_pillar,
                dtype=dtype, device=device).view(-1, 1, 1).expand(
                    num_points_in_pillar, H, W) / Z
            xs = torch.linspace(
                0.5, W - 0.5, W, dtype=dtype, device=device).view(
                    1, 1, W).expand(num_points_in_pillar, H, W) / W
            ys = torch.linspace(
                0.5, H - 0.5, H, dtype=dtype, device=device).view(
                    1, H, 1).expand(num_points_in_pillar, H, W) / H
            ref_3d = torch.stack((xs, ys, zs), -1)
            ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
            ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
            return ref_3d

        # reference points on 2D bev plane, used in temporal self-attention (TSA).
        elif dim == '2d':
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=dtype, device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=dtype, device=device))
            ref_y = ref_y.reshape(-1)[None] / H
            ref_x = ref_x.reshape(-1)[None] / W
            ref_2d = torch.stack((ref_x, ref_y), -1)
            ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
            return ref_2d

        raise ValueError(f"dim must be '3d' or '2d', got {dim!r}")

    # This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range, img_metas):
        """Project normalized 3D reference points into every camera.

        Args:
            reference_points (Tensor): (B, D, num_query, 3) normalized points,
                as returned by ``get_reference_points(dim='3d')``.
            pc_range (list[float]): [x_min, y_min, z_min, x_max, y_max, z_max]
                used to de-normalize the points.
            img_metas (list[dict]): Per-sample meta. ``lidar2img`` is stacked
                to (B, num_cam, 4, 4); ``img_shape`` of the first sample's
                first camera normalizes the image coordinates for all samples.

        Returns:
            tuple[Tensor, Tensor]:
                - reference_points_cam: (num_cam, B, num_query, D, 2) image
                  coordinates divided by image width/height.
                - bev_mask: (num_cam, B, num_query, D) mask of points with
                  positive depth that fall inside the image.
        """
        lidar2img = torch.stack([meta['lidar2img'] for meta in img_metas
                                 ]).to(reference_points.dtype)  # (B, N, 4, 4)
        reference_points = reference_points.clone()

        # De-normalize x/y/z from [0, 1] into pc_range.
        reference_points[..., 0:1] = reference_points[..., 0:1] * \
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * \
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * \
            (pc_range[5] - pc_range[2]) + pc_range[2]

        # Homogeneous coordinates.
        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)

        reference_points = reference_points.permute(1, 0, 2, 3)
        D, B, num_query = reference_points.size()[:3]
        num_cam = lidar2img.size(1)

        reference_points = reference_points.view(D, B, 1, num_query, 4).repeat(
            1, 1, num_cam, 1, 1).unsqueeze(-1)

        lidar2img = lidar2img.view(1, B, num_cam, 1, 4,
                                   4).repeat(D, 1, 1, num_query, 1, 1)

        reference_points_cam = torch.matmul(
            lidar2img.to(torch.float32),
            reference_points.to(torch.float32)).squeeze(-1)
        eps = 1e-5

        # Keep only points in front of the camera, then perspective-divide.
        bev_mask = (reference_points_cam[..., 2:3] > eps)
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3],
            torch.ones_like(reference_points_cam[..., 2:3]) * eps)

        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

        bev_mask = (
            bev_mask & (reference_points_cam[..., 1:2] > 0.0)
            & (reference_points_cam[..., 1:2] < 1.0)
            & (reference_points_cam[..., 0:1] < 1.0)
            & (reference_points_cam[..., 0:1] > 0.0))
        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
            bev_mask = torch.nan_to_num(bev_mask)
        else:
            bev_mask = bev_mask.new_tensor(
                np.nan_to_num(bev_mask.cpu().numpy()))

        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
        bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)

        return reference_points_cam, bev_mask

    @auto_fp16()
    def forward(self,
                bev_query,
                key,
                value,
                *args,
                bev_h=None,
                bev_w=None,
                bev_pos=None,
                spatial_shapes=None,
                level_start_index=None,
                valid_ratios=None,
                prev_bev=None,
                shift=0.,
                **kwargs):
        """Run every encoder layer on the BEV queries.

        Args:
            bev_query (Tensor): Input BEV query with shape
                `(num_query, bs, embed_dims)`, num_query == bev_h * bev_w.
            key, value (Tensor): Multi-camera features, presumably
                (num_cam, num_value, bs, embed_dims); forwarded unchanged.
            bev_h, bev_w (int): BEV grid size.
            bev_pos (Tensor): BEV positional encoding, same layout as
                ``bev_query``.
            spatial_shapes, level_start_index (Tensor): Describe the
                flattened camera feature levels; forwarded to the layers.
            valid_ratios: Unused.
            prev_bev (Tensor | None): Previous BEV features, same layout as
                ``bev_query``; fused by the temporal self-attention.
            shift (Tensor | float): Per-sample BEV shift of shape (bs, 2)
                added to the 2D reference points. A scalar is added as-is.
                Default: 0.
            **kwargs: Must contain ``img_metas`` (used for point sampling);
                all kwargs are forwarded to every layer.

        Returns:
            Tensor: Output of the last layer, presumably
            (bs, bev_h*bev_w, embed_dims) since the queries are made
            batch-first below; when ``return_intermediate`` is True, the
            outputs of all layers stacked along a new dim 0.
        """
        output = bev_query
        intermediate = []

        ref_3d = self.get_reference_points(
            bev_h,
            bev_w,
            self.pc_range[5] - self.pc_range[2],
            self.num_points_in_pillar,
            dim='3d',
            bs=bev_query.size(1),
            device=bev_query.device,
            dtype=bev_query.dtype)
        ref_2d = self.get_reference_points(
            bev_h,
            bev_w,
            dim='2d',
            bs=bev_query.size(1),
            device=bev_query.device,
            dtype=bev_query.dtype)

        reference_points_cam, bev_mask = self.point_sampling(
            ref_3d, self.pc_range, kwargs['img_metas'])

        # bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
        # Because of the aliasing, ref_2d itself is shifted as well.
        shift_ref_2d = ref_2d  # .clone()
        if isinstance(shift, torch.Tensor):
            shift_ref_2d += shift[:, None, None, :]
        else:
            # Scalar shift (e.g. the 0. default) cannot be indexed.
            shift_ref_2d += shift

        # (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
        bev_query = bev_query.permute(1, 0, 2)
        bev_pos = bev_pos.permute(1, 0, 2)
        bs, len_bev, num_bev_level, _ = ref_2d.shape
        if prev_bev is not None:
            prev_bev = prev_bev.permute(1, 0, 2)
            # Interleave previous and current BEV per sample: (bs*2, len_bev, C)
            prev_bev = torch.stack([prev_bev, bev_query],
                                   1).reshape(bs * 2, len_bev, -1)
            hybrid_ref_2d = torch.stack([shift_ref_2d, ref_2d],
                                        1).reshape(bs * 2, len_bev,
                                                   num_bev_level, 2)
        else:
            hybrid_ref_2d = torch.stack([ref_2d, ref_2d],
                                        1).reshape(bs * 2, len_bev,
                                                   num_bev_level, 2)

        for layer in self.layers:
            output = layer(
                bev_query,
                key,
                value,
                *args,
                bev_pos=bev_pos,
                ref_2d=hybrid_ref_2d,
                ref_3d=ref_3d,
                bev_h=bev_h,
                bev_w=bev_w,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                reference_points_cam=reference_points_cam,
                bev_mask=bev_mask,
                prev_bev=prev_bev,
                **kwargs)

            bev_query = output
            if self.return_intermediate:
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output
@TRANSFORMER.register_module()
class PerceptionTransformer(BaseModule):
    """BEVFormer perception transformer: a BEV encoder followed by a
    DETR3D-style object decoder.

    Args:
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        num_cams (int): Number of cameras. Default: 6.
        two_stage_num_proposals (int): Stored but not used by this class.
            Default: 300.
        encoder (dict): Config of the BEV encoder.
        decoder (dict): Config of the object decoder.
        embed_dims (int): Feature dimension. Default: 256.
        rotate_prev_bev (bool): Rotate the previous BEV by ``can_bus[-1]``
            before temporal fusion. Default: True.
        use_shift (bool): Shift the BEV reference points by the ego
            translation; if False the shift is zero. Default: True.
        use_can_bus (bool): Add the CAN-bus MLP embedding to the BEV
            queries. Default: True.
        can_bus_norm (bool): Append a LayerNorm to the CAN-bus MLP.
            Default: True.
        use_cams_embeds (bool): Add learned per-camera embeddings to the
            image features. Default: True.
        rotate_center (Sequence[int]): Rotation center passed to
            ``_rotate`` for the previous BEV. Copied per instance.
            Default: (100, 100).
    """

    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 encoder=None,
                 decoder=None,
                 embed_dims=256,
                 rotate_prev_bev=True,
                 use_shift=True,
                 use_can_bus=True,
                 can_bus_norm=True,
                 use_cams_embeds=True,
                 rotate_center=(100, 100),
                 **kwargs):
        super(PerceptionTransformer, self).__init__(**kwargs)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
        self.use_can_bus = use_can_bus
        self.can_bus_norm = can_bus_norm
        self.use_cams_embeds = use_cams_embeds
        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
        # Copy so instances never share (and mutate) one default list.
        self.rotate_center = list(rotate_center)

    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims))
        self.reference_points = nn.Linear(self.embed_dims, 3)
        # Maps the 18-dim CAN-bus vector to an embed_dims embedding.
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        if self.can_bus_norm:
            self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, (MSDeformableAttention3D, TemporalSelfAttention,
                              CustomMSDeformableAttention)):
                m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
        xavier_init(self.reference_points, distribution='uniform', bias=0.)
        xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)

    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'))
    def get_bev_features(self,
                         mlvl_feats,
                         bev_queries,
                         bev_h,
                         bev_w,
                         grid_length=(0.512, 0.512),
                         bev_pos=None,
                         prev_bev=None,
                         **kwargs):
        """Obtain BEV features from the multi-camera, multi-level features.

        Args:
            mlvl_feats (list[Tensor]): Per-level features, each of shape
                (bs, num_cams, embed_dims, h, w).
            bev_queries (Tensor): (bev_h*bev_w, embed_dims) BEV queries.
            bev_h, bev_w (int): BEV grid size.
            grid_length (Sequence[float]): (y, x) size of one BEV cell,
                presumably in the same units as the CAN-bus translation.
            bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w) BEV positional
                encoding.
            prev_bev (Tensor | None): Previous BEV features, either
                (bs, bev_h*bev_w, embed_dims) or (bev_h*bev_w, bs, embed_dims).
            **kwargs: Must contain ``img_metas``; each meta's ``can_bus``
                tensor is read (entries 0, 1, -2, -1 and the whole vector).
                All kwargs are forwarded to the encoder.

        Returns:
            Tensor: Output of ``self.encoder``.
        """
        img_metas = kwargs['img_metas']
        bs = mlvl_feats[0].size(0)
        # (bev_h*bev_w, C) -> (bev_h*bev_w, bs, C)
        bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
        # (bs, C, bev_h, bev_w) -> (bev_h*bev_w, bs, C)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        # obtain rotation angle and shift with ego motion
        delta_x = torch.stack([each['can_bus'][0] for each in img_metas])
        delta_y = torch.stack([each['can_bus'][1] for each in img_metas])
        ego_angle = torch.stack(
            [each['can_bus'][-2] / np.pi * 180 for each in img_metas])
        grid_length_y = grid_length[0]
        grid_length_x = grid_length[1]
        translation_length = torch.sqrt(delta_x**2 + delta_y**2)
        translation_angle = torch.atan2(delta_y, delta_x) / np.pi * 180
        bev_angle = ego_angle - translation_angle
        # Ego translation projected onto the BEV axes, as a fraction of the
        # BEV extent (cells / grid size).
        shift_y = translation_length * \
            torch.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
        shift_x = translation_length * \
            torch.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
        if not self.use_shift:
            shift_y = shift_y.new_zeros(shift_y.size())
            shift_x = shift_x.new_zeros(shift_x.size())
        # (bs, 2) with (x, y) per sample.
        shift = torch.stack([shift_x,
                             shift_y]).permute(1, 0).to(bev_queries.dtype)

        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
                # (bs, bev_h*bev_w, C) -> (bev_h*bev_w, bs, C)
                prev_bev = prev_bev.permute(1, 0, 2)
            if self.rotate_prev_bev:
                # NOTE(review): permute returns a view, so the in-place
                # writes below also modify the caller's prev_bev tensor.
                center = torch.tensor(self.rotate_center)
                for i in range(bs):
                    rotation_angle = img_metas[i]['can_bus'][-1]
                    tmp_prev_bev = prev_bev[:, i].reshape(bev_h, bev_w,
                                                          -1).permute(2, 0, 1)
                    tmp_prev_bev = _rotate(
                        tmp_prev_bev, rotation_angle, center=center)
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        bev_h * bev_w, 1, -1)
                    prev_bev[:, i] = tmp_prev_bev[:, 0]

        # add can bus signals (the MLP output is only needed when used)
        if self.use_can_bus:
            can_bus = torch.stack([each['can_bus'] for each in img_metas
                                   ]).to(bev_queries.dtype)
            can_bus = self.can_bus_mlp(can_bus)[None, :, :]
            # fix fp16
            can_bus = can_bus.to(bev_queries.dtype)
            bev_queries = bev_queries + can_bus

        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            # (bs, num_cam, C, h, w) -> (num_cam, bs, h*w, C)
            feat = feat.flatten(3).permute(1, 0, 3, 2)
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(
                feat.dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=bev_pos.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))

        feat_flatten = feat_flatten.permute(
            0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            **kwargs)

        return bev_embed

    @auto_fp16(
        apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed',
                  'prev_bev', 'bev_pos'))
    def forward(self,
                mlvl_feats,
                bev_queries,
                object_query_embed,
                bev_h,
                bev_w,
                grid_length=(0.512, 0.512),
                bev_pos=None,
                reg_branches=None,
                cls_branches=None,
                prev_bev=None,
                attn_mask=None,
                **kwargs):
        """Forward function for `PerceptionTransformer`.

        Args:
            mlvl_feats (list(Tensor)): Input features from different levels.
                Each element has shape [bs, num_cams, embed_dims, h, w].
            bev_queries (Tensor): (bev_h*bev_w, c)
            object_query_embed (Tensor): Query embedding for the decoder,
                with shape [num_query, 2 * embed_dims]; the first half is
                used as the positional query, the second half as the query.
            bev_h, bev_w (int): BEV grid size.
            grid_length (Sequence[float]): (y, x) BEV cell size.
            bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w)
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only passed when
                `with_box_refine` is True. Default to None.
            cls_branches (obj:`nn.ModuleList`): Forwarded to the decoder.
            prev_bev (Tensor | None): Previous BEV features.
            attn_mask (Tensor | None): Forwarded to the decoder.
            **kwargs: Must contain ``img_metas``; forwarded to the encoder
                and decoder.

        Returns:
            tuple[Tensor]: ``(bev_embed, inter_states, init_reference_out,
            inter_references_out)``:

            - bev_embed: BEV features, permuted to
              (bev_h*bev_w, bs, embed_dims) given the encoder returns
              (bs, bev_h*bev_w, embed_dims).
            - inter_states: Query outputs returned by the decoder.
            - init_reference_out: Initial reference points after sigmoid,
              shape (bs, num_query, 3).
            - inter_references_out: Reference points returned by the
              decoder.
        """
        bev_embed = self.get_bev_features(
            mlvl_feats,
            bev_queries,
            bev_h,
            bev_w,
            grid_length=grid_length,
            bev_pos=bev_pos,
            prev_bev=prev_bev,
            **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims

        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(
            object_query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        bev_embed = bev_embed.permute(1, 0, 2)

        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            attn_mask=attn_mask,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
            level_start_index=torch.tensor([0], device=query.device),
            **kwargs)

        inter_references_out = inter_references

        return bev_embed, inter_states, init_reference_out, inter_references_out
@POSITIONAL_ENCODING.register_module()
class LearnedPositionalEncoding(BaseModule):
    """Positional encoding built from two learnable embedding tables.

    One table is indexed by row (y) and one by column (x). For every
    position, the column feature and the row feature are concatenated.

    Args:
        num_feats (int): Feature size of each table. The returned encoding
            has ``2 * num_feats`` channels.
        row_num_embed (int, optional): Number of row embeddings.
            Default 50.
        col_num_embed (int, optional): Number of column embeddings.
            Default 50.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_feats,
                 row_num_embed=50,
                 col_num_embed=50,
                 init_cfg=dict(type='Uniform', layer='Embedding')):
        super().__init__(init_cfg)
        self.row_embed = nn.Embedding(row_num_embed, num_feats)
        self.col_embed = nn.Embedding(col_num_embed, num_feats)
        self.num_feats = num_feats
        self.row_num_embed = row_num_embed
        self.col_num_embed = col_num_embed

    def forward(self, mask):
        """Build the encoding for the spatial size of ``mask``.

        Args:
            mask (Tensor): ByteTensor mask of shape [bs, h, w]. Only its
                shape and device are used.

        Returns:
            pos (Tensor): Position embedding of shape
                [bs, num_feats*2, h, w]; the first ``num_feats`` channels
                come from the column table, the rest from the row table.
        """
        height, width = mask.shape[-2:]
        batch = mask.shape[0]
        device = mask.device

        col_feats = self.col_embed(torch.arange(width, device=device))
        row_feats = self.row_embed(torch.arange(height, device=device))

        # Tile both tables onto an (h, w, num_feats) grid.
        col_grid = col_feats.unsqueeze(0).repeat(height, 1, 1)
        row_grid = row_feats.unsqueeze(1).repeat(1, width, 1)
        grid = torch.cat((col_grid, row_grid), dim=-1)

        # (h, w, 2F) -> (1, 2F, h, w) -> (bs, 2F, h, w)
        channels_first = grid.permute(2, 0, 1).unsqueeze(0)
        return channels_first.repeat(batch, 1, 1, 1)

    def __repr__(self):
        """str: a string that describes the module"""
        return (f'{self.__class__.__name__}'
                f'(num_feats={self.num_feats}, '
                f'row_num_embed={self.row_num_embed}, '
                f'col_num_embed={self.col_num_embed})')