# easycv/models/detection/detectors/dino/deformable_transformer.py
# Copyright (c) 2022 IDEA. All Rights Reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed into multi-head attention (MHAttention)
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
import math
import random
import warnings
from typing import Optional
import torch
from torch import Tensor, nn
from easycv.framework.errors import NotImplementedError
from easycv.models.builder import NECKS
from easycv.models.detection.utils import (gen_encoder_output_proposals,
gen_sineembed_for_position,
inverse_sigmoid)
from easycv.models.utils import MLP, _get_activation_fn, _get_clones


@NECKS.register_module
class DeformableTransformer(nn.Module):
def __init__(
self,
d_model=256,
nhead=8,
num_queries=300,
num_encoder_layers=6,
num_unicoder_layers=0,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.0,
activation='relu',
normalize_before=False,
return_intermediate_dec=True,
query_dim=4,
num_patterns=0,
modulate_hw_attn=False,
# for deformable encoder
multi_encoder_memory=False,
deformable_encoder=True,
deformable_decoder=True,
num_feature_levels=1,
enc_n_points=4,
dec_n_points=4,
# init query
decoder_query_perturber=None,
add_channel_attention=False,
random_refpoints_xy=False,
# two stage
two_stage_type='no', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
two_stage_pat_embed=0,
two_stage_add_query_num=0,
two_stage_learn_wh=False,
two_stage_keep_all_tokens=False,
# evo of #anchors
dec_layer_number=None,
rm_dec_query_scale=True,
rm_self_attn_layers=None,
key_aware_type=None,
# layer share
layer_share_type=None,
# for detach
rm_detach=None,
decoder_sa_type='sa',
module_seq=['sa', 'ca', 'ffn'],
# for dn
embed_init_tgt=False,
use_detached_boxes_dec_out=False,
):
super().__init__()
self.num_feature_levels = num_feature_levels
self.num_encoder_layers = num_encoder_layers
self.num_unicoder_layers = num_unicoder_layers
self.num_decoder_layers = num_decoder_layers
self.deformable_encoder = deformable_encoder
self.deformable_decoder = deformable_decoder
self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
self.num_queries = num_queries
self.random_refpoints_xy = random_refpoints_xy
self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
assert query_dim == 4
if num_feature_levels > 1:
assert deformable_encoder, 'only support deformable_encoder for num_feature_levels > 1'
assert layer_share_type in [None, 'encoder', 'decoder', 'both']
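        # NOTE: layer sharing is effectively disabled; the assert below only
        # accepts layer_share_type=None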
        enc_layer_share = layer_share_type in ['encoder', 'both']
        dec_layer_share = layer_share_type in ['decoder', 'both']
assert layer_share_type is None
self.decoder_sa_type = decoder_sa_type
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
# choose encoder layer type
if deformable_encoder:
encoder_layer = DeformableTransformerEncoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
enc_n_points,
add_channel_attention=add_channel_attention)
else:
raise NotImplementedError
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(
encoder_layer,
num_encoder_layers,
encoder_norm,
d_model=d_model,
num_queries=num_queries,
deformable_encoder=deformable_encoder,
enc_layer_share=enc_layer_share,
two_stage_type=two_stage_type)
self.multi_encoder_memory = multi_encoder_memory
if self.multi_encoder_memory:
self.memory_reduce = nn.Linear(d_model * 2, d_model)
# choose decoder layer type
if deformable_decoder:
decoder_layer = DeformableTransformerDecoderLayer(
d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
dec_n_points,
key_aware_type=key_aware_type,
decoder_sa_type=decoder_sa_type,
module_seq=module_seq)
else:
raise NotImplementedError
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model,
query_dim=query_dim,
modulate_hw_attn=modulate_hw_attn,
num_feature_levels=num_feature_levels,
deformable_decoder=deformable_decoder,
decoder_query_perturber=decoder_query_perturber,
dec_layer_number=dec_layer_number,
rm_dec_query_scale=rm_dec_query_scale,
dec_layer_share=dec_layer_share,
use_detached_boxes_dec_out=use_detached_boxes_dec_out)
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries # useful for single stage model only
self.num_patterns = num_patterns
        if not isinstance(num_patterns, int):
            warnings.warn('num_patterns should be int but got {}'.format(
                type(num_patterns)))
            self.num_patterns = 0
if num_feature_levels > 1:
if self.num_encoder_layers > 0:
self.level_embed = nn.Parameter(
torch.Tensor(num_feature_levels, d_model))
else:
self.level_embed = None
self.embed_init_tgt = embed_init_tgt
        if two_stage_type == 'no' or embed_init_tgt:
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
nn.init.normal_(self.tgt_embed.weight.data)
else:
self.tgt_embed = None
# for two stage
self.two_stage_type = two_stage_type
self.two_stage_pat_embed = two_stage_pat_embed
self.two_stage_add_query_num = two_stage_add_query_num
self.two_stage_learn_wh = two_stage_learn_wh
assert two_stage_type in [
'no', 'standard'
], 'unknown param {} of two_stage_type'.format(two_stage_type)
if two_stage_type == 'standard':
# anchor selection at the output of encoder
self.enc_output = nn.Linear(d_model, d_model)
self.enc_output_norm = nn.LayerNorm(d_model)
if two_stage_pat_embed > 0:
self.pat_embed_for_2stage = nn.Parameter(
torch.Tensor(two_stage_pat_embed, d_model))
nn.init.normal_(self.pat_embed_for_2stage)
if two_stage_add_query_num > 0:
self.tgt_embed = nn.Embedding(self.two_stage_add_query_num,
d_model)
if two_stage_learn_wh:
self.two_stage_wh_embedding = nn.Embedding(1, 2)
else:
self.two_stage_wh_embedding = None
if two_stage_type == 'no':
self.init_ref_points(num_queries)
self.enc_out_class_embed = None
self.enc_out_bbox_embed = None
self.enc_out_center_embed = None
self.enc_out_iou_embed = None
# evolution of anchors
self.dec_layer_number = dec_layer_number
if dec_layer_number is not None:
if self.two_stage_type != 'no' or num_patterns == 0:
assert dec_layer_number[
0] == num_queries, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})'
else:
assert dec_layer_number[
0] == num_queries * num_patterns, f'dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})'
self._reset_parameters()
self.rm_self_attn_layers = rm_self_attn_layers
if rm_self_attn_layers is not None:
print('Removing the self-attn in {} decoder layers'.format(
rm_self_attn_layers))
for lid, dec_layer in enumerate(self.decoder.layers):
if lid in rm_self_attn_layers:
dec_layer.rm_self_attn_modules()
self.rm_detach = rm_detach
if self.rm_detach:
assert isinstance(rm_detach, list)
assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
self.decoder.rm_detach = rm_detach
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
from easycv.thirdparty.deformable_attention.modules import MSDeformAttn
if isinstance(m, MSDeformAttn):
m._reset_parameters()
if self.num_feature_levels > 1 and self.level_embed is not None:
nn.init.normal_(self.level_embed)
if self.two_stage_learn_wh:
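            # initialize the learned wh so that sigmoid(weight) == 0.05,
            # i.e. a small default proposal size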
nn.init.constant_(self.two_stage_wh_embedding.weight,
math.log(0.05 / (1 - 0.05)))
def get_valid_ratio(self, mask):
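        # fraction of each padded feature map that contains real image
        # content, returned as (ratio_w, ratio_h) per sample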
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
if self.random_refpoints_xy:
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
self.refpoint_embed.weight.data[:, :2])
self.refpoint_embed.weight.data[:, :2].requires_grad = False
def forward(self,
srcs,
masks,
refpoint_embed,
pos_embeds,
tgt,
attn_mask=None):
"""
Input:
            - srcs: list of multi-level features, each [bs, ci, hi, wi]
            - masks: list of multi-level masks, each [bs, hi, wi]
            - refpoint_embed: [bs, num_dn, 4]. None at inference time
            - pos_embeds: list of multi-level pos embeds, each [bs, ci, hi, wi]
            - tgt: [bs, num_dn, d_model]. None at inference time
"""
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask,
pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1) # bs, hw
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
if self.num_feature_levels > 1 and self.level_embed is not None:
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(
1, 1, -1)
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten,
1) # bs, \sum{hxw}, c
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device)
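        # start offset of each level's tokens inside the flattened
        # \sum{h*w} token sequence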
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
# two stage
enc_topk_proposals = enc_refpoint_embed = None
#########################################################
# Begin Encoder
#########################################################
memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
ref_token_index=enc_topk_proposals, # bs, nq
ref_token_coord=enc_refpoint_embed, # bs, nq, 4
)
if self.multi_encoder_memory:
memory = self.memory_reduce(torch.cat([src_flatten, memory], -1))
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, 4) or (nenc, bs, nq, 4)
#########################################################
if self.two_stage_type == 'standard':
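            # two-stage: turn every encoder token into a proposal, score it,
            # and keep the top-k as the initial decoder anchors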
if self.two_stage_learn_wh:
input_hw = self.two_stage_wh_embedding.weight[0]
else:
input_hw = None
output_memory, output_proposals = gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes, input_hw)
output_memory = self.enc_output_norm(
self.enc_output(output_memory))
if self.two_stage_pat_embed > 0:
bs, nhw, _ = output_memory.shape
# output_memory: bs, n, 256; self.pat_embed_for_2stage: k, 256
output_memory = output_memory.repeat(1,
self.two_stage_pat_embed,
1)
_pats = self.pat_embed_for_2stage.repeat_interleave(nhw, 0)
output_memory = output_memory + _pats
output_proposals = output_proposals.repeat(
1, self.two_stage_pat_embed, 1)
if self.two_stage_add_query_num > 0:
assert refpoint_embed is not None
output_memory = torch.cat((output_memory, tgt), dim=1)
output_proposals = torch.cat(
(output_proposals, refpoint_embed), dim=1)
enc_outputs_class_unselected = self.enc_out_class_embed(
output_memory)
enc_outputs_coord_unselected = self.enc_out_bbox_embed(
output_memory
) + output_proposals # (bs, \sum{hw}, 4) unsigmoid
topk = self.num_queries
topk_proposals = torch.topk(
enc_outputs_class_unselected.max(-1)[0], topk,
dim=1)[1] # bs, nq
# gather boxes
refpoint_embed_undetach = torch.gather(
enc_outputs_coord_unselected, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid
refpoint_embed_ = refpoint_embed_undetach.detach()
init_box_proposal = torch.gather(
output_proposals, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1,
4)).sigmoid() # sigmoid
# gather tgt
tgt_undetach = torch.gather(
output_memory, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
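            # content queries: either a learned embedding (embed_init_tgt) or
            # the detached features of the selected encoder tokens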
if self.embed_init_tgt:
tgt_ = self.tgt_embed.weight[:, None, :].repeat(
1, bs, 1).transpose(0, 1) # nq, bs, d_model
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
elif self.two_stage_type == 'no':
tgt_ = self.tgt_embed.weight[:,
None, :].repeat(1, bs, 1).transpose(
0, 1) # nq, bs, d_model
refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
1, bs, 1).transpose(0, 1) # nq, bs, 4
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
if self.num_patterns > 0:
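                # NOTE: self.patterns (an nn.Embedding(num_patterns, d_model))
                # is not created in __init__ and must be attached externally
                # before this branch can run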
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
self.num_queries, 1) # 1, n_q*n_pat, d_model
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError('unknown two_stage_type {}'.format(
self.two_stage_type))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
        # - refpoint_embed(unsigmoid): bs, NQ, 4
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs, references = self.decoder(
tgt=tgt.transpose(0, 1),
memory=memory.transpose(0, 1),
memory_key_padding_mask=mask_flatten,
pos=lvl_pos_embed_flatten.transpose(0, 1),
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
tgt_mask=attn_mask)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if self.two_stage_type == 'standard':
if self.two_stage_keep_all_tokens:
hs_enc = output_memory.unsqueeze(0)
ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
init_box_proposal = output_proposals
else:
hs_enc = tgt_undetach.unsqueeze(0)
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
else:
hs_enc = ref_enc = None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
        # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, query_dim) or None
#########################################################
return hs, references, hs_enc, ref_enc, init_box_proposal
# hs: (n_dec, bs, nq, d_model)
        # references: sigmoid coordinates. (n_dec+1, bs, nq, 4)
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
# ref_enc: sigmoid coordinates. \
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None


class TransformerEncoder(nn.Module):
def __init__(
self,
encoder_layer,
num_layers,
norm=None,
d_model=256,
num_queries=300,
deformable_encoder=False,
enc_layer_share=False,
enc_layer_dropout_prob=None,
two_stage_type='no', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
):
super().__init__()
# prepare layers
if num_layers > 0:
self.layers = _get_clones(
encoder_layer, num_layers, layer_share=enc_layer_share)
else:
self.layers = []
del encoder_layer
self.query_scale = None
self.num_queries = num_queries
self.deformable_encoder = deformable_encoder
self.num_layers = num_layers
self.norm = norm
self.d_model = d_model
self.enc_layer_dropout_prob = enc_layer_dropout_prob
if enc_layer_dropout_prob is not None:
assert isinstance(enc_layer_dropout_prob, list)
assert len(enc_layer_dropout_prob) == num_layers
for i in enc_layer_dropout_prob:
assert 0.0 <= i <= 1.0
self.two_stage_type = two_stage_type
if two_stage_type in ['enceachlayer', 'enclayer1']:
_proj_layer = nn.Linear(d_model, d_model)
_norm_layer = nn.LayerNorm(d_model)
if two_stage_type == 'enclayer1':
self.enc_norm = nn.ModuleList([_norm_layer])
self.enc_proj = nn.ModuleList([_proj_layer])
else:
self.enc_norm = nn.ModuleList([
copy.deepcopy(_norm_layer) for i in range(num_layers - 1)
])
self.enc_proj = nn.ModuleList([
copy.deepcopy(_proj_layer) for i in range(num_layers - 1)
])
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
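        # normalized (cx, cy) reference points at every position of every
        # level, rescaled by each image's valid (non-padded) ratio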
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(
0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(
0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (
valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(self,
src: Tensor,
pos: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
key_padding_mask: Tensor,
ref_token_index: Optional[Tensor] = None,
ref_token_coord: Optional[Tensor] = None):
"""
Input:
- src: [bs, sum(hi*wi), 256]
- pos: pos embed for src. [bs, sum(hi*wi), 256]
- spatial_shapes: h,w of each level [num_level, 2]
- level_start_index: [num_level] start point of level in sum(hi*wi).
- valid_ratios: [bs, num_level, 2]
- key_padding_mask: [bs, sum(hi*wi)]
- ref_token_index: bs, nq
- ref_token_coord: bs, nq, 4
        Intermediate:
            - reference_points: [bs, sum(hi*wi), num_level, 2]
        Output:
            - output: [bs, sum(hi*wi), 256]
"""
if self.two_stage_type in [
'no', 'standard', 'enceachlayer', 'enclayer1'
]:
assert ref_token_index is None
output = src
# preparation and reshape
if self.num_layers > 0:
if self.deformable_encoder:
reference_points = self.get_reference_points(
spatial_shapes, valid_ratios, device=src.device)
intermediate_output = []
intermediate_ref = []
if ref_token_index is not None:
out_i = torch.gather(
output, 1,
ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
intermediate_output.append(out_i)
intermediate_ref.append(ref_token_coord)
# main process
for layer_id, layer in enumerate(self.layers):
# main process
dropflag = False
if self.enc_layer_dropout_prob is not None:
prob = random.random()
if prob < self.enc_layer_dropout_prob[layer_id]:
dropflag = True
if not dropflag:
if self.deformable_encoder:
output = layer(
src=output,
pos=pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask)
else:
output = layer(
src=output.transpose(0, 1),
pos=pos.transpose(0, 1),
key_padding_mask=key_padding_mask).transpose(0, 1)
if ((layer_id == 0
and self.two_stage_type in ['enceachlayer', 'enclayer1']) or
(self.two_stage_type
== 'enceachlayer')) and (layer_id != self.num_layers - 1):
output_memory, output_proposals = gen_encoder_output_proposals(
output, key_padding_mask, spatial_shapes)
output_memory = self.enc_norm[layer_id](
self.enc_proj[layer_id](output_memory))
# gather boxes
topk = self.num_queries
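                # NOTE: self.class_embed is not defined in __init__; it is
                # expected to be attached externally for the
                # 'enceachlayer'/'enclayer1' variants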
enc_outputs_class = self.class_embed[layer_id](output_memory)
ref_token_index = torch.topk(
enc_outputs_class.max(-1)[0], topk, dim=1)[1] # bs, nq
ref_token_coord = torch.gather(
output_proposals, 1,
ref_token_index.unsqueeze(-1).repeat(1, 1, 4))
output = output_memory
# aux loss
if (layer_id !=
self.num_layers - 1) and ref_token_index is not None:
out_i = torch.gather(
output, 1,
ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
intermediate_output.append(out_i)
intermediate_ref.append(ref_token_coord)
if self.norm is not None:
output = self.norm(output)
if ref_token_index is not None:
intermediate_output = torch.stack(
intermediate_output) # n_enc/n_enc-1, bs, \sum{hw}, d_model
intermediate_ref = torch.stack(intermediate_ref)
else:
intermediate_output = intermediate_ref = None
return output, intermediate_output, intermediate_ref


class TransformerDecoder(nn.Module):
def __init__(
self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False,
d_model=256,
query_dim=4,
modulate_hw_attn=False,
num_feature_levels=1,
deformable_decoder=False,
decoder_query_perturber=None,
dec_layer_number=None, # number of queries each layer in decoder
rm_dec_query_scale=False,
dec_layer_share=False,
dec_layer_dropout_prob=None,
use_detached_boxes_dec_out=False):
super().__init__()
if num_layers > 0:
self.layers = _get_clones(
decoder_layer, num_layers, layer_share=dec_layer_share)
else:
self.layers = []
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
assert return_intermediate, 'support return_intermediate only'
self.query_dim = query_dim
assert query_dim in [
2, 4
], 'query_dim should be 2/4 but {}'.format(query_dim)
self.num_feature_levels = num_feature_levels
self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model,
2)
if not deformable_decoder:
self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
else:
self.query_pos_sine_scale = None
        if rm_dec_query_scale:
            self.query_scale = None
        else:
            # a learned scale for the query position embedding is not supported
            raise NotImplementedError
self.bbox_embed = None
self.class_embed = None
self.center_embed = None
self.iou_embed = None
self.d_model = d_model
self.modulate_hw_attn = modulate_hw_attn
self.deformable_decoder = deformable_decoder
if not deformable_decoder and modulate_hw_attn:
self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
else:
self.ref_anchor_head = None
self.decoder_query_perturber = decoder_query_perturber
self.box_pred_damping = None
self.dec_layer_number = dec_layer_number
if dec_layer_number is not None:
assert isinstance(dec_layer_number, list)
assert len(dec_layer_number) == num_layers
self.dec_layer_dropout_prob = dec_layer_dropout_prob
if dec_layer_dropout_prob is not None:
assert isinstance(dec_layer_dropout_prob, list)
assert len(dec_layer_dropout_prob) == num_layers
for i in dec_layer_dropout_prob:
assert 0.0 <= i <= 1.0
self.rm_detach = None
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
# for memory
level_start_index: Optional[Tensor] = None, # num_levels
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
valid_ratios: Optional[Tensor] = None,
):
"""
Input:
- tgt: nq, bs, d_model
- memory: hw, bs, d_model
- pos: hw, bs, d_model
- refpoints_unsigmoid: nq, bs, 2/4
- valid_ratios/spatial_shapes: bs, nlevel, 2
"""
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers):
# preprocess ref points
if self.training and self.decoder_query_perturber is not None and layer_id != 0:
reference_points = self.decoder_query_perturber(
reference_points)
if self.deformable_decoder:
if reference_points.shape[-1] == 4:
reference_points_input = reference_points[:, :, None] * torch.cat(
[valid_ratios, valid_ratios],
-1)[None, :] # nq, bs, nlevel, 4
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :,
None] * valid_ratios[
None, :]
query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :]) # nq, bs, 256*2
else:
query_sine_embed = gen_sineembed_for_position(
reference_points) # nq, bs, 256*2
reference_points_input = None
# conditional query
raw_query_pos = self.ref_point_head(
query_sine_embed) # nq, bs, 256
pos_scale = self.query_scale(
output) if self.query_scale is not None else 1
query_pos = pos_scale * raw_query_pos
if not self.deformable_decoder:
query_sine_embed = query_sine_embed[
..., :self.d_model] * self.query_pos_sine_scale(output)
# modulated HW attentions
if not self.deformable_decoder and self.modulate_hw_attn:
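                # scale the sine positional embedding by the ratio of the
                # predicted w/h to the reference box w/h (modulated attention,
                # as in DAB-DETR)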
refHW_cond = self.ref_anchor_head(
output).sigmoid() # nq, bs, 2
query_sine_embed[..., self.d_model // 2:] *= (
refHW_cond[..., 0] /
reference_points[..., 2]).unsqueeze(-1)
query_sine_embed[..., :self.d_model //
2] *= (refHW_cond[..., 1] /
reference_points[..., 3]).unsqueeze(-1)
# main process
dropflag = False
if self.dec_layer_dropout_prob is not None:
prob = random.random()
if prob < self.dec_layer_dropout_prob[layer_id]:
dropflag = True
if not dropflag:
output = layer(
tgt=output,
tgt_query_pos=query_pos,
tgt_query_sine_embed=query_sine_embed,
tgt_key_padding_mask=tgt_key_padding_mask,
tgt_reference_points=reference_points_input,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
memory_level_start_index=level_start_index,
memory_spatial_shapes=spatial_shapes,
memory_pos=pos,
self_attn_mask=tgt_mask,
cross_attn_mask=memory_mask)
# iter update
if self.bbox_embed is not None:
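                # iterative refinement: predict a delta in unsigmoid (logit)
                # space and add it to the current reference boxes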
reference_before_sigmoid = inverse_sigmoid(reference_points)
delta_unsig = self.bbox_embed[layer_id](output)
outputs_unsig = delta_unsig + reference_before_sigmoid
new_reference_points = outputs_unsig.sigmoid()
# select # ref points
if self.dec_layer_number is not None and layer_id != self.num_layers - 1:
nq_now = new_reference_points.shape[0]
select_number = self.dec_layer_number[layer_id + 1]
if nq_now != select_number:
class_unselected = self.class_embed[layer_id](
output) # nq, bs, 91
topk_proposals = torch.topk(
class_unselected.max(-1)[0], select_number,
dim=0)[1] # new_nq, bs
new_reference_points = torch.gather(
new_reference_points, 0,
topk_proposals.unsqueeze(-1).repeat(
1, 1, 4)) # unsigmoid
if self.rm_detach and 'dec' in self.rm_detach:
reference_points = new_reference_points
else:
reference_points = new_reference_points.detach()
if self.use_detached_boxes_dec_out:
ref_points.append(reference_points)
else:
ref_points.append(new_reference_points)
intermediate.append(self.norm(output))
if self.dec_layer_number is not None and layer_id != self.num_layers - 1:
if nq_now != select_number:
                    output = torch.gather(
                        output, 0,
                        topk_proposals.unsqueeze(-1).repeat(
                            1, 1, self.d_model))  # keep only the selected queries
return [[itm_out.transpose(0, 1) for itm_out in intermediate],
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]]


class DeformableTransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation='relu',
n_levels=4,
n_heads=8,
n_points=4,
add_channel_attention=False,
):
super().__init__()
# self attention
from easycv.thirdparty.deformable_attention.modules import MSDeformAttn
self.self_attn = MSDeformAttn(
d_model, n_levels, n_heads, n_points, im2col_step=64)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# channel attention
self.add_channel_attention = add_channel_attention
if add_channel_attention:
self.activ_channel = _get_activation_fn('dyrelu')
self.norm_channel = nn.LayerNorm(d_model)
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self,
src,
pos,
reference_points,
spatial_shapes,
level_start_index,
key_padding_mask=None):
# self attention
src2 = self.self_attn(
self.with_pos_embed(src, pos), reference_points, src,
spatial_shapes, level_start_index, key_padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
# channel attn
if self.add_channel_attention:
src = self.norm_channel(src + self.activ_channel(src))
return src


class DeformableTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model=256,
d_ffn=1024,
dropout=0.1,
activation='relu',
n_levels=4,
n_heads=8,
n_points=4,
key_aware_type=None,
decoder_sa_type='ca',
module_seq=['sa', 'ca', 'ffn'],
):
super().__init__()
self.module_seq = module_seq
assert sorted(module_seq) == ['ca', 'ffn', 'sa']
# cross attention
from easycv.thirdparty.deformable_attention.modules import MSDeformAttn
self.cross_attn = MSDeformAttn(
d_model, n_levels, n_heads, n_points, im2col_step=64)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
# self attention
self.self_attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
# ffn
self.linear1 = nn.Linear(d_model, d_ffn)
self.activation = _get_activation_fn(activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ffn, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
self.key_aware_type = key_aware_type
self.key_aware_proj = None
self.decoder_sa_type = decoder_sa_type
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
if decoder_sa_type == 'ca_content':
from easycv.thirdparty.deformable_attention.modules import MSDeformAttn
self.self_attn = MSDeformAttn(
d_model, n_levels, n_heads, n_points, im2col_step=64)
def rm_self_attn_modules(self):
self.self_attn = None
self.dropout2 = None
self.norm2 = None
@staticmethod
def with_pos_embed(tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_sa(
self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[
Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[
Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[
Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[
Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[
Tensor] = None, # mask used for cross-attention
):
# self attention
if self.self_attn is not None:
if self.decoder_sa_type == 'sa':
q = k = self.with_pos_embed(tgt, tgt_query_pos)
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
elif self.decoder_sa_type == 'ca_label':
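                # NOTE: self.label_embedding is not defined in this class and
                # is expected to be attached externally when 'ca_label' is used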
bs = tgt.shape[1]
k = v = self.label_embedding.weight[:,
None, :].repeat(1, bs, 1)
tgt2 = self.self_attn(tgt, k, v, attn_mask=self_attn_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
elif self.decoder_sa_type == 'ca_content':
tgt2 = self.self_attn(
self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
tgt_reference_points.transpose(0, 1).contiguous(),
memory.transpose(0, 1), memory_spatial_shapes,
memory_level_start_index,
memory_key_padding_mask).transpose(0, 1)
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
else:
raise NotImplementedError('Unknown decoder_sa_type {}'.format(
self.decoder_sa_type))
return tgt
def forward_ca(
self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[
Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[
Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[
Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[
Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[
Tensor] = None, # mask used for cross-attention
):
# cross attention
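        # key-aware variants add a global summary of the memory (mean or
        # projected mean) to the queries before cross-attention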
if self.key_aware_type is not None:
if self.key_aware_type == 'mean':
tgt = tgt + memory.mean(0, keepdim=True)
elif self.key_aware_type == 'proj_mean':
tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
else:
raise NotImplementedError('Unknown key_aware_type: {}'.format(
self.key_aware_type))
tgt2 = self.cross_attn(
self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
tgt_reference_points.transpose(0, 1).contiguous(),
memory.transpose(0, 1), memory_spatial_shapes,
memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
return tgt
def forward(
self,
# for tgt
tgt: Optional[Tensor], # nq, bs, d_model
tgt_query_pos: Optional[
Tensor] = None, # pos for query. MLP(Sine(pos))
tgt_query_sine_embed: Optional[
Tensor] = None, # pos for query. Sine(pos)
tgt_key_padding_mask: Optional[Tensor] = None,
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
# for memory
memory: Optional[Tensor] = None, # hw, bs, d_model
memory_key_padding_mask: Optional[Tensor] = None,
memory_level_start_index: Optional[Tensor] = None, # num_levels
memory_spatial_shapes: Optional[
Tensor] = None, # bs, num_levels, 2
memory_pos: Optional[Tensor] = None, # pos for memory
# sa
self_attn_mask: Optional[
Tensor] = None, # mask used for self-attention
cross_attn_mask: Optional[
Tensor] = None, # mask used for cross-attention
):
for funcname in self.module_seq:
if funcname == 'ffn':
tgt = self.forward_ffn(tgt)
elif funcname == 'ca':
tgt = self.forward_ca(tgt, tgt_query_pos, tgt_query_sine_embed,
tgt_key_padding_mask,
tgt_reference_points, memory,
memory_key_padding_mask,
memory_level_start_index,
memory_spatial_shapes, memory_pos,
self_attn_mask, cross_attn_mask)
elif funcname == 'sa':
tgt = self.forward_sa(tgt, tgt_query_pos, tgt_query_sine_embed,
tgt_key_padding_mask,
tgt_reference_points, memory,
memory_key_padding_mask,
memory_level_start_index,
memory_spatial_shapes, memory_pos,
self_attn_mask, cross_attn_mask)
else:
raise ValueError('unknown funcname {}'.format(funcname))
return tgt