in easycv/models/detection/detectors/dino/deformable_transformer.py [0:0]
def forward(self,
            src: Tensor,
            pos: Tensor,
            spatial_shapes: Tensor,
            level_start_index: Tensor,
            valid_ratios: Tensor,
            key_padding_mask: Tensor,
            ref_token_index: Optional[Tensor] = None,
            ref_token_coord: Optional[Tensor] = None):
    """Run the encoder stack over the flattened multi-level feature map.

    Optionally (for the 'enceachlayer'/'enclayer1' two-stage variants) it
    re-selects top-k reference tokens between layers and collects the
    gathered token features/coordinates for auxiliary supervision.

    Input:
        - src: [bs, sum(hi*wi), 256]
        - pos: pos embed for src. [bs, sum(hi*wi), 256]
        - spatial_shapes: h,w of each level [num_level, 2]
        - level_start_index: [num_level] start point of level in sum(hi*wi).
        - valid_ratios: [bs, num_level, 2]
        - key_padding_mask: [bs, sum(hi*wi)]
        - ref_token_index: bs, nq
        - ref_token_coord: bs, nq, 4
    Intermediate:
        - reference_points: [bs, sum(hi*wi), num_level, 2]
    Outputs:
        - output: [bs, sum(hi*wi), 256]
        - intermediate_output / intermediate_ref: stacked per-layer gathered
          token features and coordinates, or None when no reference tokens
          were ever set.
    """
    # These two-stage modes compute reference tokens internally (or not at
    # all), so the caller must not pass any in.
    if self.two_stage_type in [
            'no', 'standard', 'enceachlayer', 'enclayer1'
    ]:
        assert ref_token_index is None
    output = src
    # preparation and reshape
    if self.num_layers > 0:
        if self.deformable_encoder:
            # Per-pixel sampling anchors, shared by every deformable layer.
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=src.device)
    intermediate_output = []
    intermediate_ref = []
    if ref_token_index is not None:
        # Record the layer-0 (input) features at the reference tokens.
        out_i = torch.gather(
            output, 1,
            ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
        intermediate_output.append(out_i)
        intermediate_ref.append(ref_token_coord)
    # main process
    for layer_id, layer in enumerate(self.layers):
        # main process
        # Stochastic layer drop: skip this encoder layer entirely with the
        # configured per-layer probability (training-time regularization;
        # NOTE(review): also active at eval if prob list is set — confirm).
        dropflag = False
        if self.enc_layer_dropout_prob is not None:
            prob = random.random()
            if prob < self.enc_layer_dropout_prob[layer_id]:
                dropflag = True
        if not dropflag:
            if self.deformable_encoder:
                output = layer(
                    src=output,
                    pos=pos,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    key_padding_mask=key_padding_mask)
            else:
                # Non-deformable layer expects sequence-first tensors, hence
                # the transpose in and back out.
                output = layer(
                    src=output.transpose(0, 1),
                    pos=pos.transpose(0, 1),
                    key_padding_mask=key_padding_mask).transpose(0, 1)
        # Between-layer proposal selection: 'enclayer1' does it only after
        # layer 0, 'enceachlayer' after every layer; never after the last
        # layer (its output goes to the decoder directly).
        if ((layer_id == 0
             and self.two_stage_type in ['enceachlayer', 'enclayer1']) or
            (self.two_stage_type
             == 'enceachlayer')) and (layer_id != self.num_layers - 1):
            output_memory, output_proposals = gen_encoder_output_proposals(
                output, key_padding_mask, spatial_shapes)
            output_memory = self.enc_norm[layer_id](
                self.enc_proj[layer_id](output_memory))
            # gather boxes: pick the top-k tokens by max class logit and
            # fetch their proposal boxes; these replace any previous
            # reference tokens for the following layers.
            topk = self.num_queries
            enc_outputs_class = self.class_embed[layer_id](output_memory)
            ref_token_index = torch.topk(
                enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # bs, nq
            ref_token_coord = torch.gather(
                output_proposals, 1,
                ref_token_index.unsqueeze(-1).repeat(1, 1, 4))
            output = output_memory
        # aux loss: collect features/coords at the current reference tokens
        # after every layer except the last.
        if (layer_id !=
                self.num_layers - 1) and ref_token_index is not None:
            out_i = torch.gather(
                output, 1,
                ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
            intermediate_output.append(out_i)
            intermediate_ref.append(ref_token_coord)
    if self.norm is not None:
        output = self.norm(output)
    if ref_token_index is not None:
        intermediate_output = torch.stack(
            intermediate_output)  # n_enc/n_enc-1, bs, \sum{hw}, d_model
        intermediate_ref = torch.stack(intermediate_ref)
    else:
        intermediate_output = intermediate_ref = None
    return output, intermediate_output, intermediate_ref