def losses()

in models/UN_EPT.py [0:0]


    def losses(self, seg_logit, pred_mask, pred_direction, seg_label, distance_map, angle_map):
        """Compute segmentation loss.

        The returned ``loss_seg`` is a weighted sum of four terms:

        * ``self.loss_decode`` on the coarse ``seg_logit``;
        * a 2-class cross-entropy on ``pred_mask`` against the mask label
          built from ``distance_map``;
        * a cross-entropy on ``pred_direction`` against the direction label
          built from ``angle_map``, restricted to pixels the model currently
          predicts as mask-positive;
        * ``self.loss_decode`` on ``refine_map``, i.e. ``seg_logit`` shifted by
          the offsets derived from the mask and direction predictions.

        Args:
            seg_logit: Segmentation logits. Passed unchanged (not resized here)
                to ``self.loss_decode``, ``self.shift`` and ``accuracy``.
            pred_mask: Mask logits with 2 classes on dim 1; resized bilinearly
                to the label resolution before use.
            pred_direction: Direction logits, one channel per direction class
                on dim 1; resized bilinearly to the label resolution.
            seg_label: Ground-truth labels with a singleton dim 1 (squeezed
                before the segmentation losses) — presumably (N, 1, H, W).
            distance_map: Input to ``self.distance_to_mask_label``.
            angle_map: Input to ``self.angle_to_direction_label``.

        Returns:
            dict: ``'loss_seg'`` (weighted total loss) and ``'acc_seg'``
            (``accuracy`` of ``seg_logit`` against the squeezed ``seg_label``).
        """
        # Relative weights of the four loss terms, and the probability above
        # which a pixel counts as mask-positive.
        seg_loss_weight = 0.8
        mask_loss_weight = 5
        direction_loss_weight = 0.6
        refine_loss_weight = 1.0
        mask_threshold = 0.5

        loss = dict()
        seg_weight = None

        # Mask label derived from the distance map. It is indexed as
        # gt_mask[:, 0] below, so presumably (N, 1, H, W) — TODO confirm.
        gt_mask = self.distance_to_mask_label(distance_map, seg_label, return_tensor=True)
        gt_size = gt_mask.shape[2:]
        mask_weights = self.calc_weights(gt_mask, 2)

        # Bring both auxiliary predictions to the label resolution.
        pred_direction = F.interpolate(pred_direction, size=gt_size, mode="bilinear", align_corners=True)
        pred_mask = F.interpolate(pred_mask, size=gt_size, mode="bilinear", align_corners=True)
        mask_loss = F.cross_entropy(pred_mask, gt_mask[:, 0], weight=mask_weights, ignore_index=-1)

        # Hard decision on the predicted mask; pixels below the threshold are
        # handed to the direction labelling as extra ignore pixels. The
        # comparison is not differentiable, so no autograd graph is needed.
        with torch.no_grad():
            binary_pred_mask = torch.softmax(pred_mask, dim=1)[:, 1, :, :] > mask_threshold

        gt_direction = self.angle_to_direction_label(
            angle_map,
            seg_label_map=seg_label,
            extra_ignore_mask=(binary_pred_mask == 0),
            return_tensor=True,
        )

        direction_loss_mask = gt_direction != -1
        if direction_loss_mask.any():
            direction_weights = self.calc_weights(gt_direction[direction_loss_mask], pred_direction.size(1))
            direction_loss = F.cross_entropy(
                pred_direction, gt_direction[:, 0], weight=direction_weights, ignore_index=-1)
        else:
            # Every direction target is ignored (e.g. the predicted mask is
            # empty). A mean cross-entropy over zero targets is NaN and would
            # poison the total loss, so use a zero that stays attached to the
            # graph; the direction branch then receives zero gradients.
            direction_loss = pred_direction.sum() * 0

        offset = self._get_offset(pred_mask, pred_direction)
        # The offset is permuted (0, 3, 1, 2) before shifting — presumably
        # channels-last from _get_offset to channels-first; TODO confirm.
        refine_map = self.shift(seg_logit, offset.permute(0, 3, 1, 2))

        seg_label = seg_label.squeeze(1)

        seg_loss = self.loss_decode(
            seg_logit,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        refine_loss = self.loss_decode(
            refine_map,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)

        # Same summation order as the terms are listed in the docstring.
        loss['loss_seg'] = (seg_loss_weight * seg_loss
                            + mask_loss_weight * mask_loss
                            + direction_loss_weight * direction_loss
                            + refine_loss_weight * refine_loss)

        loss['acc_seg'] = accuracy(seg_logit, seg_label)

        return loss