def forward()

in projects_oss/detr/detr/models/deformable_detr.py


    def forward(self, samples: NestedTensor):
        """The forward expects a NestedTensor, which consists of:
           - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
           - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

        It returns a dict with the following elements:
           - "pred_logits": the classification logits (including no-object) for all queries.
                            Shape= [batch_size x num_queries x (num_classes + 1)]
           - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                           (center_x, center_y, height, width). These values are normalized in [0, 1],
                           relative to the size of each individual image (disregarding possible padding).
                           See PostProcess for information on how to retrieve the unnormalized bounding box.
           - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                            dictionnaries containing the two above keys for each decoder layer.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        # features is a list of num_levels NestedTensor.
        # pos is a list of num_levels tensors. Each one has shape (B, hidden_dim, H_l, W_l).
        features, pos = self.backbone(samples)
        # srcs is a list of num_levels tensors. Each one has shape (B, C, H_l, W_l).
        srcs = []
        # masks is a list of num_levels tensors. Each one has shape (B, H_l, W_l).
        masks = []
        for l, feat in enumerate(features):
            # src shape: (B, C, H_l, W_l)
            # mask shape: (B, H_l, W_l)
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None

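        # When the backbone returns fewer maps than num_feature_levels (e.g. a
        # ResNet yielding 3 scales for a 4-level model), the missing levels are
        # synthesized below: the first extra level from the last backbone map,
        # each further level from the previously created one (the extra
        # input_proj entries are stride-2 convolutions in the reference
        # implementation).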
        if self.num_feature_levels > len(srcs):
            N, C, H, W = samples.tensor.size()
            sample_mask = torch.ones((N, H, W), dtype=torch.bool, device=src.device)
            for idx in range(N):
                image_size = samples.image_sizes[idx]
                h, w = image_size
                sample_mask[idx, :h, :w] = False
            # sample_mask shape (1, N, H, W)
            sample_mask = sample_mask[None].float()

            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                b, _, h, w = src.size()
                # mask shape: (B, H_l, W_l)
                mask = F.interpolate(sample_mask, size=src.shape[-2:]).to(torch.bool)[0]
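                # self.backbone[1] is the position-embedding module (the
                # backbone is a Joiner of [body, position encoding]).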
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

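        # In the two-stage variant the decoder queries are built from encoder
        # proposals inside the transformer, so no learned query embeddings are
        # passed in.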
        query_embeds = None
        if not self.two_stage:
            # shape (num_queries, hidden_dim*2)
            query_embeds = self.query_embed.weight

        # hs shape: (num_layers, batch_size, num_queries, hidden_dim)
        # init_reference shape: (batch_size, num_queries, 2)
        # inter_references shape: (num_layers, batch_size, num_queries, 2) -- or 4 with box refinement
        (
            hs,
            init_reference,
            inter_references,
            enc_outputs_class,
            enc_outputs_coord_unact,
        ) = self.transformer(srcs, masks, pos, query_embeds)
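        # enc_outputs_class / enc_outputs_coord_unact are encoder-side proposal
        # predictions; they are only populated in the two-stage variant (see
        # the "enc_outputs" branch below).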

        outputs_classes = []
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            # reference shape: (batch_size, num_queries, 2) -- or 4 with box refinement
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            # shape (batch_size, num_queries, num_classes)
            outputs_class = self.class_embed[lvl](hs[lvl])
            # shape (batch_size, num_queries, 4). Pre-sigmoid box parameters (cx, cy, w, h)
            tmp = self.bbox_embed[lvl](hs[lvl])
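            # bbox_embed predicts offsets relative to the reference points; the
            # offsets are added in pre-sigmoid (logit) space and squashed to
            # [0, 1] by the sigmoid below.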
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            # shape (batch_size, num_queries, 4). 4-tuple (cx, cy, w, h)
            outputs_coord = tmp.sigmoid()

            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        # shape (num_layers, batch_size, num_queries, num_classes)
        outputs_class = torch.stack(outputs_classes)
        # shape (num_layers, batch_size, num_queries, 4)
        outputs_coord = torch.stack(outputs_coords)

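        # The main outputs come from the last decoder layer; earlier layers are
        # exposed via "aux_outputs" for deep supervision when aux_loss is set.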
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.aux_loss:
            out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)

        if self.two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out["enc_outputs"] = {
                "pred_logits": enc_outputs_class,
                "pred_boxes": enc_outputs_coord,
            }
        return out
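
A minimal usage sketch (illustrative, not from the source): it assumes an already
constructed DeformableDETR instance named model and the nested_tensor_from_tensor_list
helper imported in this file; the image sizes are arbitrary.

    import torch

    # Two images of different sizes; nested_tensor_from_tensor_list pads them to a
    # common (H, W) and marks the padded pixels in samples.mask.
    images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    samples = nested_tensor_from_tensor_list(images)

    with torch.no_grad():
        out = model(samples)

    # out["pred_logits"]: (batch_size, num_queries, num_classes) classification logits
    # out["pred_boxes"]:  (batch_size, num_queries, 4) normalized (cx, cy, w, h) boxes
    # Per-class probabilities via sigmoid (focal-loss training, no softmax over classes).
    probs = out["pred_logits"].sigmoid()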