def forward()

in projects_oss/detr/detr/models/deformable_transformer.py [0:0]


    def forward(self, srcs, masks, pos_embeds, query_embed=None):
        """
        Args:
            srcs: a list of num_levels tensors. Each has shape (N, C, H_l, W_l)
            masks: a list of num_levels tensors. Each has shape (N, H_l, W_l)
            pos_embeds: a list of num_levels tensors. Each has shape (N, C, H_l, W_l)
            query_embed: a tensor of shape (num_queries, 2*C). It is split below into a
                query positional embedding and a target embedding of C channels each.
                May be None when self.two_stage is True.
        """
        assert self.two_stage or query_embed is not None

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            # src shape (bs, h_l*w_l, c)
            src = src.flatten(2).transpose(1, 2)
            # mask shape (bs, h_l*w_l)
            mask = mask.flatten(1)
            # pos_embed shape (bs, h_l*w_l, c)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            # lvl_pos_embed shape (bs, h_l*w_l, c)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        # src_flatten shape: (bs, K, c) where K = \sum_l H_l * W_l
        src_flatten = torch.cat(src_flatten, 1)
        # mask_flatten shape: (bs, K)
        mask_flatten = torch.cat(mask_flatten, 1)
        # lvl_pos_embed_flatten shape: (bs, K, c)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
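        # all feature levels now share a single flattened axis of length K, so one
        # multi-scale deformable attention module can attend across every level at once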
        # spatial_shapes shape: (num_levels, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )
        # level_start_index shape: (num_levels,)
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
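        # e.g. spatial_shapes [[64, 64], [32, 32], [16, 16]] gives
        # level_start_index [0, 4096, 5120]: the offset of each level's first position
        # along the flattened length-K axis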
        # valid_ratios shape: (bs, num_levels, 2)
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
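        # each entry holds the fraction of a level's width and height covered by real
        # (non-padded) image content, used to rescale normalized reference points
        # consistently across levels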
        # encoder
        # memory shape (bs, K, C) where K = \sum_l H_l * W_l
        memory = self.encoder(
            src_flatten,
            spatial_shapes,
            level_start_index,
            valid_ratios,
            lvl_pos_embed_flatten,
            mask_flatten,
        )
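        # memory keeps the same flattened (bs, K, c) layout as src_flatten, so the
        # spatial_shapes / level_start_index bookkeeping above also applies to the
        # decoder's cross-attention over it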

        # prepare input for decoder
        bs, _, c = memory.shape
        if self.two_stage:
            # output_memory shape (bs, K, C)
            # output_proposals shape (bs, K, 4)
            # output_proposals_valid shape (bs, K, 1)
            (
                output_memory,
                output_proposals,
                output_proposals_valid,
            ) = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
            # hack implementation for two-stage Deformable DETR
            # shape (bs, K, 1)
            enc_outputs_class = self.encoder.class_embed(output_memory)
            # fill in -inf foreground logit at invalid positions so that we will never pick
            # top-scored proposals at those positions
            enc_outputs_class.masked_fill_(mask_flatten.unsqueeze(-1), NEG_INF)
            enc_outputs_class.masked_fill_(~output_proposals_valid, NEG_INF)
            # shape (bs, K, 4)
            enc_outputs_coord_unact = (
                self.encoder.bbox_embed(output_memory) + output_proposals
            )
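            # both terms live in inverse-sigmoid ("unactivated") space: output_proposals
            # provides a per-position prior box and bbox_embed predicts an offset on top of it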
            topk = self.two_stage_num_proposals
            # topk_proposals: indices of top items. Shape (bs, top_k)
            topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]

            # topk_coords_unact shape (bs, top_k, 4)
            topk_coords_unact = torch.gather(
                enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
            )
            topk_coords_unact = topk_coords_unact.detach()
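            # detached so no gradient flows from the decoder back into the top-k proposal
            # coordinates; the decoder treats these initial references as constants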

            init_reference_out = topk_coords_unact
            # pos_trans_out shape (bs, top_k, 2*c); split below into query_embed and tgt
            pos_trans_out = self.pos_trans_norm(
                self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
            )
            # query_embed shape (bs, top_k, c)
            # tgt shape (bs, top_k, c)
            query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
        else:
            # query_embed (or tgt) shape: (num_queries, c)
            query_embed, tgt = torch.split(query_embed, c, dim=1)
            # query_embed shape: (batch_size, num_queries, c)
            query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
            # tgt shape: (batch_size, num_queries, c)
            tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
            # init_reference_out shape: (batch_size, num_queries, 2)
            init_reference_out = self.reference_points(query_embed)
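            # self.reference_points is a learned projection that maps each query embedding
            # to an initial 2-d reference point used by the decoder's deformable attention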

        # decoder
        # hs shape: (num_layers, batch_size, num_queries, c)
        # inter_references shape: (num_layers, batch_size, num_queries, num_levels, 2)
        hs, inter_references = self.decoder(
            tgt,
            init_reference_out,
            memory,
            spatial_shapes,
            level_start_index,
            valid_ratios,
            query_embed,
            mask_flatten,
        )
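        # inter_references collects the reference points produced after every decoder
        # layer; with iterative bounding-box refinement they are updated layer by layer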

        inter_references_out = inter_references
        if self.two_stage:
            return (
                hs,
                init_reference_out,
                inter_references_out,
                enc_outputs_class,
                enc_outputs_coord_unact,
            )
        # hs shape: (num_layers, batch_size, num_queries, c)
        # init_reference_out shape: (batch_size, num_queries, 2)
        # inter_references_out shape: (num_layers, batch_size, num_queries, num_levels, 2)
        return hs, init_reference_out, inter_references_out, None, None
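
A minimal usage sketch for the single-stage path (two_stage=False), assuming a
DeformableTransformer instance bound to a variable named model and built with hidden
dimension C=256 and three feature levels; the tensors below are random placeholders
standing in for real backbone features, padding masks, and positional encodings:

    import torch

    bs, C, num_queries = 2, 256, 300
    shapes = [(64, 64), (32, 32), (16, 16)]
    # backbone feature maps, padding masks (False = valid pixel), positional encodings
    srcs = [torch.randn(bs, C, h, w) for h, w in shapes]
    masks = [torch.zeros(bs, h, w, dtype=torch.bool) for h, w in shapes]
    pos_embeds = [torch.randn(bs, C, h, w) for h, w in shapes]
    # (num_queries, 2*C): split inside forward() into query positional and target embeddings
    query_embed = torch.randn(num_queries, 2 * C)

    hs, init_ref, inter_refs, enc_cls, enc_coord = model(srcs, masks, pos_embeds, query_embed)
    # hs: (num_layers, bs, num_queries, C); enc_cls and enc_coord are None because the
    # encoder proposal head only runs when two_stage=True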