def forward()

in d2go/modeling/subclass.py [0:0]


    def forward(self, images, features, proposals, targets=None):
        """
        Same as StandardROIHeads.forward but add logic for subclass.

        When ``self.subclass_on`` is false the call is delegated unchanged to
        the parent implementation. Otherwise the box branch is run here, the
        box-head features are flattened and fed to ``self.subclass_head``, and
        the subclass result is attached to the outputs.

        Args:
            images: not used by this method (dropped immediately); it is only
                forwarded to the parent when the subclass head is disabled.
            features: mapping that is indexed with every name in
                ``self.in_features`` to collect the feature maps to pool from.
            proposals: per-image proposals; each element must expose
                ``proposal_boxes`` and support ``len()``.
            targets: ground truth handed to ``label_and_sample_proposals``;
                only read in training mode.

        Returns:
            In training mode ``(proposals, losses)``, where ``proposals`` are
            the sampled proposals and ``losses`` holds the box losses, the
            mask/keypoint losses and ``"loss_subclass"``.
            In inference mode ``(pred_instances, {})``, where every element of
            ``pred_instances`` carries a ``pred_subclass_prob`` field with the
            softmax subclass probabilities of its kept detections.
        """
        if not self.subclass_on:
            return super().forward(images, features, proposals, targets)

        # --- start copy -------------------------------------------------------
        del images

        if self.training:
            proposals = self.label_and_sample_proposals(proposals, targets)
            # NOTE: `has_gt` = False for negatives and we must manually register `gt_subclasses`,
            #  because custom gt_* fields will not be automatically registered in sampled proposals.
            for pp_per_im in proposals:
                if not pp_per_im.has("gt_subclasses"):
                    background_subcls_idx = 0
                    # Allocate on the proposals' own device rather than
                    # hard-coding CUDA, so CPU training works and labels never
                    # end up on a different device than the boxes.
                    pp_per_im.gt_subclasses = torch.full(
                        (len(pp_per_im),),
                        background_subcls_idx,
                        dtype=torch.int64,
                        device=pp_per_im.proposal_boxes.device,
                    )
        del targets

        features_list = [features[f] for f in self.in_features]

        box_features = self.box_pooler(
            features_list, [x.proposal_boxes for x in proposals]
        )
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)
        # --- end copy ---------------------------------------------------------

        # NOTE: unlike the copied parent code, box_features is kept alive here
        # because the subclass head consumes it. Flatten everything after the
        # first dimension so each proposal becomes one feature vector.
        box_features = box_features.view(
            box_features.shape[0], np.prod(box_features.shape[1:])
        )
        pred_subclass_logits = self.subclass_head(box_features)

        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # During training the proposals used by the box head are
            # used by the mask, keypoint (and densepose) heads.
            losses.update(self._forward_mask(features, proposals))
            losses.update(self._forward_keypoint(features, proposals))

            # subclass head
            gt_subclasses = cat([p.gt_subclasses for p in proposals], dim=0)
            loss_subclass = F.cross_entropy(
                pred_subclass_logits, gt_subclasses, reduction="mean"
            )
            losses.update({"loss_subclass": loss_subclass})

            return proposals, losses
        else:
            pred_instances, kept_indices = self.box_predictor.inference(
                predictions, proposals
            )
            # During inference cascaded prediction is used: the mask and keypoints
            # heads are only applied to the top scoring box detections.
            pred_instances = self.forward_with_given_boxes(features, pred_instances)

            # subclass head
            probs = F.softmax(pred_subclass_logits, dim=-1)
            # `probs` stacks the proposals of all images along dim 0, whereas
            # the kept indices are expected to be relative to each image's own
            # proposals (this is what detectron2's fast_rcnn_inference returns).
            # Split per image first so images after the first one do not read
            # the probabilities of image 0. The single-image path (also the only
            # one allowed for ONNX export below) is left exactly as before.
            if len(proposals) > 1:
                probs_per_image = probs.split([len(p) for p in proposals], dim=0)
            else:
                probs_per_image = [probs]
            for pred_instances_i, kept_indices_i, probs_i in zip(
                pred_instances, kept_indices, probs_per_image
            ):
                pred_instances_i.pred_subclass_prob = torch.index_select(
                    probs_i,
                    dim=0,
                    index=kept_indices_i.to(torch.int64),
                )

            if torch.onnx.is_in_onnx_export():
                # Export only supports a single image; give the output a
                # stable name via `alias`.
                assert len(pred_instances) == 1
                pred_instances[0].pred_subclass_prob = alias(
                    pred_instances[0].pred_subclass_prob, "subclass_prob_nms"
                )

            return pred_instances, {}