in easycv/models/detection/detectors/dino/deformable_transformer.py [0:0]
def forward(self,
srcs,
masks,
refpoint_embed,
pos_embeds,
tgt,
attn_mask=None):
"""
Input:
- srcs: List of multi features [bs, ci, hi, wi]
- masks: List of multi masks [bs, hi, wi]
- refpoint_embed: [bs, num_dn, 4]. None in infer
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
- tgt: [bs, num_dn, d_model]. None in infer
"""
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(
                zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1) # bs, hw
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
if self.num_feature_levels > 1 and self.level_embed is not None:
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(
1, 1, -1)
else:
lvl_pos_embed = pos_embed
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # bs, \sum{hw}, c
        mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hw}
        lvl_pos_embed_flatten = torch.cat(
            lvl_pos_embed_flatten, 1)  # bs, \sum{hw}, c
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1, )),
             spatial_shapes.prod(1).cumsum(0)[:-1]))
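        # level_start_index marks where each level begins in the flattened
        # token sequence, e.g. spatial_shapes [[76, 106], [38, 53]] gives
        # level_start_index [0, 8056] since 76 * 106 = 8056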
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
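        # valid_ratios: [bs, n_levels, 2], the fraction of each (padded)
        # feature map that holds real image content, used to rescale
        # reference points per level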
# two stage
enc_topk_proposals = enc_refpoint_embed = None
#########################################################
# Begin Encoder
#########################################################
memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
src_flatten,
pos=lvl_pos_embed_flatten,
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
key_padding_mask=mask_flatten,
ref_token_index=enc_topk_proposals, # bs, nq
ref_token_coord=enc_refpoint_embed, # bs, nq, 4
)
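        # optionally fuse the raw flattened features back into the encoder
        # output; memory_reduce projects the concatenated 2 * d_model
        # features down to d_model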
if self.multi_encoder_memory:
memory = self.memory_reduce(torch.cat([src_flatten, memory], -1))
#########################################################
# End Encoder
# - memory: bs, \sum{hw}, c
# - mask_flatten: bs, \sum{hw}
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
        # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, 4) or (nenc, bs, nq, 4)
#########################################################
if self.two_stage_type == 'standard':
if self.two_stage_learn_wh:
input_hw = self.two_stage_wh_embedding.weight[0]
else:
input_hw = None
output_memory, output_proposals = gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes, input_hw)
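            # output_proposals: per-token anchor boxes (cx, cy, w, h) in
            # inverse-sigmoid space; padded positions are masked inside
            # gen_encoder_output_proposals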
output_memory = self.enc_output_norm(
self.enc_output(output_memory))
if self.two_stage_pat_embed > 0:
bs, nhw, _ = output_memory.shape
# output_memory: bs, n, 256; self.pat_embed_for_2stage: k, 256
                output_memory = output_memory.repeat(
                    1, self.two_stage_pat_embed, 1)
_pats = self.pat_embed_for_2stage.repeat_interleave(nhw, 0)
output_memory = output_memory + _pats
output_proposals = output_proposals.repeat(
1, self.two_stage_pat_embed, 1)
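            # optionally append the denoising queries (tgt / refpoint_embed)
            # to the proposal pool so they are scored by the same
            # encoder-side heads below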
if self.two_stage_add_query_num > 0:
assert refpoint_embed is not None
output_memory = torch.cat((output_memory, tgt), dim=1)
output_proposals = torch.cat(
(output_proposals, refpoint_embed), dim=1)
enc_outputs_class_unselected = self.enc_out_class_embed(
output_memory)
enc_outputs_coord_unselected = self.enc_out_bbox_embed(
output_memory
) + output_proposals # (bs, \sum{hw}, 4) unsigmoid
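            # the bbox head predicts deltas on top of the anchor proposals,
            # so coordinates stay in inverse-sigmoid space until a final
            # .sigmoid()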
topk = self.num_queries
topk_proposals = torch.topk(
enc_outputs_class_unselected.max(-1)[0], topk,
dim=1)[1] # bs, nq
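            # topk_proposals holds the indices of the num_queries tokens with
            # the highest class logit; these seed the decoder queries in the
            # two-stage scheme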
# gather boxes
refpoint_embed_undetach = torch.gather(
enc_outputs_coord_unselected, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid
refpoint_embed_ = refpoint_embed_undetach.detach()
            init_box_proposal = torch.gather(
                output_proposals, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid()
# gather tgt
tgt_undetach = torch.gather(
output_memory, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
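            # decoder content queries: either learned embeddings
            # (embed_init_tgt) or the detached encoder features of the
            # selected proposals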
            if self.embed_init_tgt:
                tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                    1, bs, 1).transpose(0, 1)  # bs, nq, d_model
else:
tgt_ = tgt_undetach.detach()
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
elif self.two_stage_type == 'no':
            tgt_ = self.tgt_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # bs, nq, d_model
            refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(
                1, bs, 1).transpose(0, 1)  # bs, nq, 4
if refpoint_embed is not None:
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_],
dim=1)
tgt = torch.cat([tgt, tgt_], dim=1)
else:
refpoint_embed, tgt = refpoint_embed_, tgt_
if self.num_patterns > 0:
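                # broadcast the learned pattern embeddings over the query
                # set: nq queries become nq * num_patterns decoder queries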
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
self.num_queries, 1) # 1, n_q*n_pat, d_model
tgt = tgt_embed + tgt_pat
init_box_proposal = refpoint_embed_.sigmoid()
else:
raise NotImplementedError('unknown two_stage_type {}'.format(
self.two_stage_type))
#########################################################
# End preparing tgt
# - tgt: bs, NQ, d_model
        # - refpoint_embed (unsigmoid): bs, NQ, 4
#########################################################
#########################################################
# Begin Decoder
#########################################################
hs, references = self.decoder(
tgt=tgt.transpose(0, 1),
memory=memory.transpose(0, 1),
memory_key_padding_mask=mask_flatten,
pos=lvl_pos_embed_flatten.transpose(0, 1),
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
level_start_index=level_start_index,
spatial_shapes=spatial_shapes,
valid_ratios=valid_ratios,
tgt_mask=attn_mask)
#########################################################
# End Decoder
# hs: n_dec, bs, nq, d_model
# references: n_dec+1, bs, nq, query_dim
#########################################################
#########################################################
# Begin postprocess
#########################################################
if self.two_stage_type == 'standard':
if self.two_stage_keep_all_tokens:
hs_enc = output_memory.unsqueeze(0)
ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
init_box_proposal = output_proposals
else:
hs_enc = tgt_undetach.unsqueeze(0)
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
else:
hs_enc = ref_enc = None
#########################################################
# End postprocess
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
        # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, query_dim) or None
#########################################################
return hs, references, hs_enc, ref_enc, init_box_proposal
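
# A minimal shape-level sketch of calling forward at inference (assumed
# shapes: bs=2, d_model=256, two feature levels; `transformer` is an already
# constructed instance of this class; refpoint_embed and tgt are None because
# there are no denoising queries at inference):
#
#   srcs = [torch.rand(2, 256, 76, 106), torch.rand(2, 256, 38, 53)]
#   masks = [torch.zeros(2, 76, 106, dtype=torch.bool),
#            torch.zeros(2, 38, 53, dtype=torch.bool)]
#   pos_embeds = [torch.rand_like(s) for s in srcs]
#   hs, references, hs_enc, ref_enc, init_box_proposal = transformer(
#       srcs, masks, None, pos_embeds, None)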