# projects_oss/detr/detr/models/deformable_transformer.py
def forward(self, srcs, masks, pos_embeds, query_embed=None):
"""
Args:
srcs: a list of num_levels tensors. Each has shape (N, C, H_l, W_l)
masks: a list of num_levels tensors. Each has shape (N, H_l, W_l)
pos_embeds: a list of num_levels tensors. Each has shape (N, C, H_l, W_l)
query_embed: a tensor has shape (num_queries, C)
"""
assert self.two_stage or query_embed is not None
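# In single-stage mode the decoder queries come from the learned query_embed, so it must
# be provided; in two-stage mode they are derived from top-scoring encoder proposals below.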
# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
# src shape (bs, h_l*w_l, c)
src = src.flatten(2).transpose(1, 2)
# mask shape (bs, h_l*w_l)
mask = mask.flatten(1)
# pos_embed shape (bs, h_l*w_l, c)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
# lvl_pos_embed shape (bs, h_l*w_l, c)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
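# level_embed holds one learned c-dim vector per feature level (indexed by lvl above);
# adding it to the positional encoding lets attention tell apart tokens that came from
# different pyramid levels after they are concatenated into one sequence.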
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)
# src_flatten shape: (bs, K, c) where K = \sum_l H_l * W_l
src_flatten = torch.cat(src_flatten, 1)
# mask_flatten shape: (bs, K)
mask_flatten = torch.cat(mask_flatten, 1)
# lvl_pos_embed_flatten shape: (bs, K, c)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
# spatial_shapes shape: (num_levels, 2)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=src_flatten.device
)
# level_start_index shape: (num_levels,)
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
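# Illustrative example (hypothetical shapes): if spatial_shapes were
# [[100, 150], [50, 75], [25, 38]], the per-level token counts are [15000, 3750, 950]
# and level_start_index = [0, 15000, 18750], i.e. the offset of each level's first
# token within the flattened length-K sequence.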
# valid_ratios shape: (bs, num_levels, 2)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
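# Each entry (per image and per level) is the fraction of the possibly padded feature map
# that corresponds to real image content along each axis; reference points are scaled by
# it downstream so that sampling stays inside the valid (non-padded) region.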
# encoder
# memory shape (bs, K, C) where K = \sum_l H_l * W_l
memory = self.encoder(
src_flatten,
spatial_shapes,
level_start_index,
valid_ratios,
lvl_pos_embed_flatten,
mask_flatten,
)
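# memory is the concatenation of encoder-refined tokens from all levels; the decoder's
# deformable cross-attention reads from it, using spatial_shapes and level_start_index
# to locate each level inside the flattened sequence.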
# prepare input for decoder
bs, _, c = memory.shape
if self.two_stage:
# output_memory shape (bs, K, C)
# output_proposals shape (bs, K, 4)
# output_proposals_valid shape (bs, K, 1)
(
output_memory,
output_proposals,
output_proposals_valid,
) = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
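# output_proposals assigns each encoder token an initial box in the same unactivated
# (inverse-sigmoid) space as the box head output it is added to below, while
# output_proposals_valid flags which of those proposals are usable.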
# hack implementation for two-stage Deformable DETR
# shape (bs, K, 1)
enc_outputs_class = self.encoder.class_embed(output_memory)
# fill in -inf foreground logits at padded or otherwise invalid positions so that the
# top-k selection below never picks proposals at those positions
enc_outputs_class = enc_outputs_class.masked_fill(mask_flatten.unsqueeze(-1), NEG_INF)
enc_outputs_class = enc_outputs_class.masked_fill(~output_proposals_valid, NEG_INF)
# shape (bs, K, 4)
enc_outputs_coord_unact = (
self.encoder.bbox_embed(output_memory) + output_proposals
)
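# the box head therefore predicts a per-token delta on top of the token's initial
# proposal; the result stays unactivated (hence the "_unact" suffix), and the sigmoid
# that maps it to normalized boxes is presumably applied later, outside this method.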
topk = self.two_stage_num_proposals
# topk_proposals: indices of top items. Shape (bs, top_k)
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
# topk_coords_unact shape (bs, top_k, 4)
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
)
topk_coords_unact = topk_coords_unact.detach()
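# detach() blocks gradients from flowing back through the selected proposal coordinates,
# so proposal selection behaves like a fixed region-proposal step from the decoder's
# point of view.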
init_reference_out = topk_coords_unact
# shape (bs, top_k, 2*c)
pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
)
# query_embed shape (bs, top_k, c)
# tgt shape (bs, top_k, c)
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
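# pos_trans (plus normalization) expands the proposal position embedding to 2*c channels;
# the first half becomes the decoder query positional embedding and the second half the
# initial target (content) features, mirroring the single-stage split below.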
else:
# split the (num_queries, 2*c) query_embed into a positional part and tgt (content part), each of shape (num_queries, c)
query_embed, tgt = torch.split(query_embed, c, dim=1)
# query_embed shape: (batch_size, num_queries, c)
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
# tgt shape: (batch_size, num_queries, c)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
# init_reference_out shape: (batch_size, num_queries, 2)
init_reference_out = self.reference_points(query_embed)
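# self.reference_points maps each query positional embedding to an initial 2-d reference
# point; the decoder then refines references layer by layer (see inter_references below).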
# decoder
# hs shape: (num_layers, batch_size, num_queries, c)
# inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2)
hs, inter_references = self.decoder(
tgt,
init_reference_out,
memory,
spatial_shapes,
level_start_index,
valid_ratios,
query_embed,
mask_flatten,
)
inter_references_out = inter_references
if self.two_stage:
return (
hs,
init_reference_out,
inter_references_out,
enc_outputs_class,
enc_outputs_coord_unact,
)
# hs shape: (num_layers, batch_size, num_queries, c)
# init_reference_out shape: (batch_size, num_queries, 2)
# inter_references_out shape: (num_layers, bs, num_queries, num_levels, 2)
return hs, init_reference_out, inter_references_out, None, None
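# Usage sketch (illustrative only; the instance name, batch size, channel width, query count
# and the assumption of a model built for two feature levels are all hypothetical):
#   srcs = [torch.randn(2, 256, 64, 64), torch.randn(2, 256, 32, 32)]
#   masks = [torch.zeros(2, 64, 64, dtype=torch.bool), torch.zeros(2, 32, 32, dtype=torch.bool)]
#   pos_embeds = [torch.randn_like(s) for s in srcs]
#   query_embed = torch.randn(300, 512)  # (num_queries, 2*C); only needed when two_stage is False
#   hs, init_ref, inter_refs, enc_cls, enc_coord = transformer(srcs, masks, pos_embeds, query_embed)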