evaluation/tiny_benchmark/maskrcnn_benchmark/modeling/rpn/locnet/loss.py (179 lines of code) (raw):

""" This file contains specific functions for computing losses of FCOS file """ import torch from torch.nn import functional as F from torch import nn from ..utils import concat_box_prediction_layers from maskrcnn_benchmark.layers import IOULoss from maskrcnn_benchmark.layers import SigmoidFocalLoss from maskrcnn_benchmark.layers.sigmoid_focal_loss import FixSigmoidFocalLoss, L2LossWithLogit from maskrcnn_benchmark.layers.ghm_loss import GHMC from maskrcnn_benchmark.modeling.matcher import Matcher from maskrcnn_benchmark.modeling.utils import cat from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist from .target_generator import * class LOCLossComputation(object): """ This class computes the FCOS losses. """ def __init__(self, cfg): # self.cls_loss_func = SigmoidFocalLoss( # cfg.MODEL.LOC.LOSS_GAMMA, # cfg.MODEL.LOC.LOSS_ALPHA # ) cls_loss_name = cfg.MODEL.LOC.CLS_LOSS self.cls_divide_pos_num = True if cls_loss_name == 'fixed_focal_loss': self.cls_loss_func = FixSigmoidFocalLoss( cfg.MODEL.LOC.LOSS_GAMMA, cfg.MODEL.LOC.LOSS_ALPHA ) elif cls_loss_name == 'L2': self.cls_loss_func = L2LossWithLogit() elif cls_loss_name == 'GHMC': self.cls_loss_func = GHMC(bins=cfg.MODEL.LOC.LOSS_GHMC_BINS, alpha=cfg.MODEL.LOC.LOSS_GHMC_ALPHA, momentum=cfg.MODEL.LOC.LOSS_GHMC_MOMENTUM) self.cls_divide_pos_num = False # we make use of IOU Loss for bounding boxes regression, # but we found that L1 in log scale can yield a similar performance self.box_reg_loss_func = IOULoss() if cfg.MODEL.LOC.TARGET_GENERATOR == 'fcos' and cfg.MODEL.LOC.FCOS_CENTERNESS: self.centerness_loss_func = nn.BCEWithLogitsLoss() self.prepare_targets = build_target_generator(cfg) self.cls_loss_weight = cfg.MODEL.LOC.CLS_WEIGHT self.centerness_weight_reg = cfg.MODEL.LOC.TARGET_GENERATOR == 'fcos' and cfg.MODEL.LOC.FCOS_CENTERNESS_WEIGHT_REG self.debug_vis_labels = cfg.MODEL.LOC.DEBUG.VIS_LABELS self.cls_divide_pos_sum = cfg.MODEL.LOC.DIVIDE_POS_SUM if self.cls_divide_pos_sum: self.cls_divide_pos_num = False def __call__(self, locations, box_cls, box_regression, centerness, targets): """ Arguments: locations (list[BoxList]) box_cls (list[Tensor]) box_regression (list[Tensor]) centerness (list[Tensor]) targets (list[BoxList]) Returns: cls_loss (Tensor) reg_loss (Tensor) centerness_loss (Tensor) """ N = box_cls[0].size(0) num_classes = box_cls[0].size(1) labels, reg_targets = self.prepare_targets(locations, targets) if self.debug_vis_labels: show_label_map(labels, box_cls) box_cls_flatten = [] box_regression_flatten = [] labels_flatten = [] reg_targets_flatten = [] for l in range(len(labels)): box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes)) box_regression_flatten.append(box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4)) labels_flatten.append(labels[l].reshape(-1, num_classes)) # changed reg_targets_flatten.append(reg_targets[l].reshape(-1, 4)) box_cls_flatten = torch.cat(box_cls_flatten, dim=0) box_regression_flatten = torch.cat(box_regression_flatten, dim=0) labels_flatten = torch.cat(labels_flatten, dim=0) reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0) # class loss label_flatten_max = labels_flatten.max(dim=1)[0] pos_inds = torch.nonzero(label_flatten_max > 0).squeeze(1) pos_sum = labels_flatten.sum() cls_losses = self.cls_loss_func( box_cls_flatten, labels_flatten ) if isinstance(cls_losses, (list,)): for i in range(len(cls_losses)): if self.cls_divide_pos_num: cls_losses[i] /= (pos_inds.numel() + N) # add N to avoid dividing by a zero elif self.cls_divide_pos_sum: cls_losses[i] /= (pos_sum + N) else: if self.cls_divide_pos_num: cls_losses /= (pos_inds.numel() + N) # add N to avoid dividing by a zero elif self.cls_divide_pos_sum: cls_losses /= (pos_sum + N) # reg loss box_regression_flatten = box_regression_flatten[pos_inds] reg_targets_flatten = reg_targets_flatten[pos_inds] if pos_inds.numel() > 0: if self.centerness_weight_reg: reg_weights = centerness_targets = self.prepare_targets.compute_centerness_targets(reg_targets_flatten) else: reg_weights = label_flatten_max[pos_inds] reg_loss = self.box_reg_loss_func( box_regression_flatten, reg_targets_flatten, reg_weights ) else: reg_loss = box_regression_flatten.sum() if isinstance(cls_losses, (list,)): losses = {"loss_cls{}".format(i): cls_loss * self.cls_loss_weight for i, cls_loss in enumerate(cls_losses)} losses['loss_reg'] = reg_loss else: losses = { "loss_cls": cls_losses * self.cls_loss_weight, "loss_reg": reg_loss } # centerness loss if centerness is not None: centerness_flatten = [centerness[l].reshape(-1) for l in range(len(centerness))] centerness_flatten = torch.cat(centerness_flatten, dim=0) centerness_flatten = centerness_flatten[pos_inds] if pos_inds.numel() > 0: centerness_loss = self.centerness_loss_func( centerness_flatten, centerness_targets ) else: centerness_loss = centerness_flatten.sum() losses["loss_centerness"] = centerness_loss return losses def make_location_loss_evaluator(cfg): loss_evaluator = LOCLossComputation(cfg) return loss_evaluator class LabelMapShower(object): def __init__(self, area_ths=1, show_iter=1): self.area_ths = area_ths self.show_iter = show_iter self.counter = 0 self.merge_levels = True self.show_classes = None # torch.Tensor([15]).long() - 1 self.merge_method = 'max' # merge class and fpn levels' method assert self.merge_method in ['max', 'sum'] def __call__(self, labels, box_cls): if (self.counter // self.area_ths + 1) % self.show_iter != 0: self.counter += 1 return self.counter += 1 labels = labels.copy() for i, (label, cls) in enumerate(zip(labels, box_cls)): # labels[i] = (label > 0).float().reshape((2, 1, cls.shape[-2], cls.shape[-1])) if self.show_classes is not None: label = label[:, self.show_classes] if self.merge_method == 'sum': labels[i] = label.sum(dim=1) elif self.merge_method == 'max': labels[i] = label.max(dim=1)[0] labels[i] = labels[i].reshape((cls.shape[0], 1, cls.shape[-2], cls.shape[-1])) if self.merge_levels: label_map = 0 else: label_maps = [] shape, pos_count = None, [] for i in range(0, len(labels)): label_sum = labels[i].sum() if shape is None: if label_sum > 0: shape = labels[i].shape label = labels[i] if not self.merge_levels: label_maps.append(label) else: label_map = label elif label_sum > 0: if self.merge_levels: label = F.upsample(labels[i], shape[2:], mode='bilinear') if self.merge_method == 'max': label_map = torch.max(torch.stack([label_map, label]), dim=0)[0] elif self.merge_method == 'sum': label_map += label else: label_maps.append(labels[i]) pos_count.append(int(label_sum.cpu().numpy())) # print(label_map.shape) import matplotlib.pyplot as plt import numpy as np if self.merge_levels: label_maps = [label_map] else: # ms = max([max(label_map.shape) for label_map in label_maps]) plt.figure(figsize=(5*len(label_maps), 5*1)) for i, label_map in enumerate(label_maps): label_map = F.upsample(label_map, (140, 100), mode='bilinear') label_map = label_map[0].permute((1, 2, 0)).cpu().numpy()[:, :, 0].astype(np.float32) ** 2 max_l = label_map.max() if max_l > 0: label_map /= max_l if len(label_maps) > 1: plt.subplot(1, len(label_maps), i + 1) plt.imshow(label_map) plt.title("pos_count:{} ".format(pos_count)) plt.show() show_label_map = LabelMapShower()