in models/UN_EPT.py [0:0]
def losses(self, seg_logit, pred_mask, pred_direction, seg_label, distance_map, angle_map):
    """Compute segmentation loss.

    ``loss['loss_seg']`` is the weighted sum of four terms, accumulated in
    this order:

    * ``0.8 *`` ``self.loss_decode`` on the raw ``seg_logit``;
    * ``5 *`` weighted cross-entropy between ``pred_mask`` and the mask label
      produced by ``self.distance_to_mask_label``;
    * ``0.6 *`` weighted cross-entropy between ``pred_direction`` and the
      direction label produced by ``self.angle_to_direction_label``. Pixels
      whose predicted mask probability (channel 1) is not above 0.5 are
      passed as ``extra_ignore_mask``;
    * ``1 *`` ``self.loss_decode`` on ``seg_logit`` shifted by the offsets
      derived from the mask/direction predictions.

    Args:
        seg_logit: Segmentation logits. NOTE(review): not resized here, so
            presumably already at label resolution — confirm with caller.
        pred_mask: Mask logits; softmax is taken over dim 1 and channel 1 is
            thresholded, so at least two channels are expected.
        pred_direction: Direction logits; ``size(1)`` is used as the number
            of direction classes.
        seg_label: Segmentation label; ``squeeze(1)`` is applied before the
            decode losses, so it presumably carries a singleton channel dim.
        distance_map: Input for ``self.distance_to_mask_label``.
        angle_map: Input for ``self.angle_to_direction_label``.

    Returns:
        dict: ``'loss_seg'`` (combined loss tensor) and ``'acc_seg'``
        (``accuracy`` of the raw ``seg_logit`` against ``seg_label``).
    """
    loss = dict()
    seg_weight = None

    gt_mask = self.distance_to_mask_label(distance_map, seg_label, return_tensor=True)
    gt_size = gt_mask.shape[2:]
    mask_weights = self.calc_weights(gt_mask, 2)

    # Bring both auxiliary predictions to the label resolution before any loss.
    pred_direction = F.interpolate(pred_direction, size=gt_size, mode="bilinear", align_corners=True)
    pred_mask = F.interpolate(pred_mask, size=gt_size, mode="bilinear", align_corners=True)
    mask_loss = F.cross_entropy(pred_mask, gt_mask[:, 0], weight=mask_weights, ignore_index=-1)

    # Direction supervision is restricted to pixels the mask head currently
    # predicts as channel 1. The threshold comparison is non-differentiable,
    # so detach to avoid recording an autograd graph that is never used.
    mask_threshold = 0.5
    binary_pred_mask = torch.softmax(pred_mask.detach(), dim=1)[:, 1, :, :] > mask_threshold
    gt_direction = self.angle_to_direction_label(
        angle_map,
        seg_label_map=seg_label,
        extra_ignore_mask=~binary_pred_mask,
        return_tensor=True,
    )
    direction_target = gt_direction[:, 0]
    if (direction_target != -1).any():
        direction_loss_mask = gt_direction != -1
        direction_weights = self.calc_weights(gt_direction[direction_loss_mask], pred_direction.size(1))
        direction_loss = F.cross_entropy(
            pred_direction, direction_target, weight=direction_weights, ignore_index=-1
        )
    else:
        # Every target equals ignore_index: mean-reduced cross_entropy would be
        # NaN (0/0) and poison loss_seg. Use a zero that stays attached to the
        # graph so pred_direction still receives (zero) gradients.
        direction_loss = pred_direction.sum() * 0.0

    # offset is permuted (0, 3, 1, 2) before shifting, i.e. it appears to be
    # channels-last. NOTE(review): confirm against _get_offset.
    offset = self._get_offset(pred_mask, pred_direction)
    refine_map = self.shift(seg_logit, offset.permute(0, 3, 1, 2))

    seg_label = seg_label.squeeze(1)
    coarse_seg_loss = self.loss_decode(
        seg_logit,
        seg_label,
        weight=seg_weight,
        ignore_index=self.ignore_index)
    refine_seg_loss = self.loss_decode(
        refine_map,
        seg_label,
        weight=seg_weight,
        ignore_index=self.ignore_index)
    loss['loss_seg'] = 0.8 * coarse_seg_loss + 5 * mask_loss + 0.6 * direction_loss + refine_seg_loss
    loss['acc_seg'] = accuracy(seg_logit, seg_label)
    return loss