def forward()

in seamseg/models/panoptic.py [0:0]


    def forward(self, img, msk=None, cat=None, iscrowd=None, bbx=None, do_loss=False, do_prediction=True):
        """Run the body, RPN, instance-segmentation and semantic-segmentation stages.

        Args:
            img: input images, padded to a common size by ``pad_packed_images``
                (presumably a packed sequence of image tensors -- TODO confirm).
            msk, cat, iscrowd, bbx: ground-truth annotations. They are only read when
                ``do_loss`` is true. ``cat``, ``iscrowd`` and ``bbx`` are first converted
                by ``self._prepare_inputs``. ``msk`` is also passed, unchanged, to the
                instance-segmentation training step.
            do_loss: run the training steps of every stage and collect their losses
                (plus the semantic confusion matrix).
            do_prediction: run the inference steps and collect their predictions.

        Returns:
            A ``(loss, pred, conf)`` triple of ``OrderedDict`` objects:

            - ``loss``: obj_loss, bbx_loss, roi_cls_loss, roi_bbx_loss, roi_msk_loss,
              sem_loss. All are ``None`` unless ``do_loss``.
            - ``pred``: bbx_pred, cls_pred, obj_pred, msk_pred (``None`` unless
              ``do_prediction``), and sem_pred. sem_pred is also produced by the
              semantic training step, so it is set whenever ``do_loss`` or
              ``do_prediction`` is true.
            - ``conf``: sem_conf, ``None`` unless ``do_loss``.
        """
        # Pad the batch. `valid_size` comes back from pad_packed_images alongside the
        # padded tensor (presumably the pre-padding sizes -- TODO confirm).
        img, valid_size = pad_packed_images(img)
        img_size = img.shape[-2:]

        # Every output starts as None; the stages below only fill in what was requested.
        obj_loss = bbx_loss = proposals = None
        roi_cls_loss = roi_bbx_loss = roi_msk_loss = None
        bbx_pred = cls_pred = obj_pred = msk_pred = None
        sem_loss = conf_mat = sem_pred = None

        if do_loss:
            # Convert the ground truth into the form the training steps consume.
            gt_cat, gt_iscrowd, gt_bbx, gt_ids, gt_sem = self._prepare_inputs(msk, cat, iscrowd, bbx)

        features = self.body(img)

        # RPN. The training step is asked to also run inference, so that proposals
        # exist for the ROI stage even when losses are being computed.
        if do_loss:
            obj_loss, bbx_loss, proposals = self.rpn_algo.training(
                self.rpn_head, features, gt_bbx, gt_iscrowd, valid_size,
                training=self.training, do_inference=True)
        elif do_prediction:
            proposals = self.rpn_algo.inference(self.rpn_head, features, valid_size, self.training)

        # ROI / instance segmentation on the proposals. Loss and prediction are independent.
        if do_loss:
            roi_cls_loss, roi_bbx_loss, roi_msk_loss = self.instance_seg_algo.training(
                self.roi_head, features, proposals, gt_bbx, gt_cat, gt_iscrowd, gt_ids, msk, img_size)
        if do_prediction:
            bbx_pred, cls_pred, obj_pred, msk_pred = self.instance_seg_algo.inference(
                self.roi_head, features, proposals, valid_size, img_size)

        # Semantic segmentation. The training step returns a prediction as well.
        if do_loss:
            sem_loss, conf_mat, sem_pred = self.semantic_seg_algo.training(
                self.sem_head, features, gt_sem, valid_size, img_size)
        elif do_prediction:
            sem_pred = self.semantic_seg_algo.inference(self.sem_head, features, valid_size, img_size)

        # Assemble the outputs; insertion order defines the key order callers see.
        loss = OrderedDict()
        loss["obj_loss"] = obj_loss
        loss["bbx_loss"] = bbx_loss
        loss["roi_cls_loss"] = roi_cls_loss
        loss["roi_bbx_loss"] = roi_bbx_loss
        loss["roi_msk_loss"] = roi_msk_loss
        loss["sem_loss"] = sem_loss

        pred = OrderedDict()
        pred["bbx_pred"] = bbx_pred
        pred["cls_pred"] = cls_pred
        pred["obj_pred"] = obj_pred
        pred["msk_pred"] = msk_pred
        pred["sem_pred"] = sem_pred

        conf = OrderedDict()
        conf["sem_conf"] = conf_mat

        return loss, pred, conf