# evaluation/tiny_benchmark/maskrcnn_benchmark/modeling/rpn/gaussian_net/loss.py
"""
This file contains specific functions for computing losses of FCOS
file
"""
import torch
from torch.nn import functional as F
from torch import nn
from ..utils import concat_box_prediction_layers
from maskrcnn_benchmark.layers import IOULoss
from maskrcnn_benchmark.layers import SigmoidFocalLoss
from maskrcnn_benchmark.layers.sigmoid_focal_loss import FixSigmoidFocalLoss
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
INF = 100000000


class TargetGenerator(object):
def __init__(self, beta, num_classes, object_sizes_of_interest, label_radius=1.0):
self.beta = beta
self.inflection_point = 0.25
self.num_classes = num_classes
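        # Why inflection_point and sigma (derivation added for reference, assuming beta > 1):
        # the matching score below is Q = exp(-D) with
        #     D = (|x - cx| / (sigma * W)) ** beta + (|y - cy| / (sigma * H)) ** beta,
        # and the 1-D factor exp(-(d / (sigma * W)) ** beta) has its inflection point where
        # (d / (sigma * W)) ** beta = (beta - 1) / beta, i.e. at d = sigma * W * ((beta - 1) / beta) ** (1 / beta).
        # With sigma = inflection_point * (beta / (beta - 1)) ** (1 / beta), that inflection point
        # falls at d = inflection_point * W = 0.25 * W, a quarter of the box width (same for H).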
        self.sigma = self.inflection_point * ((beta / (beta - 1)) ** (1.0 / beta))
self.object_sizes_of_interest = object_sizes_of_interest
self.eps = 1e-6
self.label_radius = label_radius
def __call__(self, locations, targets):
object_sizes_of_interest = self.object_sizes_of_interest
cls_labels = []
matched_gt_idxs = []
self.care = [0, 0]
for l, locations_level in enumerate(locations):
cls_label, matched_gt_idx = self.prepare_target_per_level(locations_level, targets, object_sizes_of_interest[l], l)
cls_label = torch.cat(cls_label, dim=0) # cat all image label together
matched_gt_idx = torch.cat(matched_gt_idx, dim=0)
cls_labels.append(cls_label)
matched_gt_idxs.append(matched_gt_idx)
return cls_labels, matched_gt_idxs
def prepare_target_per_level(self, locations, targets, object_sizes, level=0):
"""
match_gt_idx = match_gt_idxs[img_idx]
match_gt_idx[loc_idx, class_id-1] = { # class_id start from 1
-1, if no object match
object_idx, if match object with object_idx
}
"""
beta = self.beta
sigma = self.sigma
cls_labels = []
matched_gt_idxs = []
xs, ys = locations[:, 0], locations[:, 1]
fpn_stride = xs[1] - xs[0]
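        # NOTE: this assumes the level's locations are laid out row-major with x varying fastest
        # (as produced by FCOS-style compute_locations), so adjacent entries differ by one stride.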
for im_i in range(len(targets)):
# select object for this fpn level
targets_per_im = targets[im_i]
targets_per_im = targets_per_im.convert('xywh')
sizes = torch.sqrt((targets_per_im.bbox[:, 2] * targets_per_im.bbox[:, 3]))
min_e = targets_per_im.bbox[:, [2, 3]].min(dim=1)[0] / fpn_stride
            # because the cross check (cf. GAULossComputation.valid_pos) needs 3 points sharing the same x or y
is_card1_in_the_level = (sizes <= object_sizes[1]) & (sizes > object_sizes[0])
# is_card2_in_the_level = min_e < 6
# if level > 0:
# is_card1_in_the_level = is_card1_in_the_level & (min_e >= 3)
# is_card2_in_the_level = (min_e >= 3) & is_card2_in_the_level
# is_card_in_the_level = is_card1_in_the_level | is_card2_in_the_level
is_card_in_the_level = is_card1_in_the_level
targets_per_im = targets_per_im[is_card_in_the_level]
cls_label = torch.zeros(size=(len(xs), self.num_classes), device=xs.device)
matched_gt_idx = torch.zeros(size=(len(xs), self.num_classes), device=xs.device).long() - 1
if len(targets_per_im) == 0:
cls_labels.append(cls_label)
matched_gt_idxs.append(matched_gt_idx)
continue
# get gt-boxes infos
targets_per_im = targets_per_im.convert('xyxy')
bboxes = targets_per_im.bbox
cx = (bboxes[:, 0] + bboxes[:, 2]) / 2
cy = (bboxes[:, 1] + bboxes[:, 3]) / 2
W = bboxes[:, 2] - bboxes[:, 0] + 1
H = bboxes[:, 3] - bboxes[:, 1] + 1
# match locations to bbox one by one, and get the score
D = ((xs[:, None] - cx[None, :]).abs() / (sigma * W[None, :])) ** beta + \
((ys[:, None] - cy[None, :]).abs() / (sigma * H[None, :])) ** beta
Q = torch.exp(-D)
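            # Q is a generalized-Gaussian matching score in (0, 1]: 1 at a box centre and, e.g. for
            # beta = 2, about exp(-0.5) ~= 0.61 at a quarter of the box width away from the centre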
            # clip the Gaussian's support: locations inside the (label_radius-scaled) box keep their
            # positive score, locations outside are zeroed and act as negatives
Q = Q * self.is_in_boxes(xs, ys, bboxes)
# generate label map
dis = (xs[:, None] - cx[None, :]) ** 2 + (ys[:, None] - cy[None, :]) ** 2
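            # NOTE: `dis` (squared distance to each GT centre) is computed but not used below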
labels_per_im = targets_per_im.get_field("labels").to(xs.device)
card_idx = torch.nonzero(is_card_in_the_level).squeeze(dim=1)
            for c in labels_per_im.unique():  # set(tensor) would not deduplicate (tensors hash by id)
targets_the_class = torch.nonzero(labels_per_im == c).squeeze(dim=1)
Qc = Q[:, targets_the_class]
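                # each location keeps the highest Gaussian score over all GT boxes of class c and
                # records which GT produced it; card_idx maps the level-filtered index back to the
                # index in the full per-image target list, and near-zero scores stay unmatched (-1)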
cls_label[:, c-1], m_idx = Qc.max(dim=1)
matched_gt_idx[:, c-1] = torch.where(
cls_label[:, c-1] > self.eps, card_idx[targets_the_class[m_idx]], torch.LongTensor([-1]).to(m_idx.device))
matched_gt_idxs.append(matched_gt_idx)
cls_labels.append(cls_label)
return cls_labels, matched_gt_idxs
def is_in_boxes(self, xs, ys, bboxes):
cx = (bboxes[:, 0] + bboxes[:, 2]) / 2
cy = (bboxes[:, 1] + bboxes[:, 3]) / 2
W = bboxes[:, 2] - bboxes[:, 0] + 1
H = bboxes[:, 3] - bboxes[:, 1] + 1
W = (W * self.label_radius).clamp(3)
H = (H * self.label_radius).clamp(3)
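        # .clamp(3) is .clamp(min=3): the (label_radius-scaled) positive region never shrinks
        # below 3 pixels per side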
x1 = cx - (W - 1) / 2
x2 = cx + (W - 1) / 2
y1 = cy - (H - 1) / 2
y2 = cy + (H - 1) / 2
l = xs[:, None] - x1[None]
t = ys[:, None] - y1[None]
r = x2[None] - xs[:, None]
b = y2[None] - ys[:, None]
reg_targets_per_im = torch.stack([l, t, r, b], dim=2)
is_in_boxes = (reg_targets_per_im.min(dim=2)[0] > 0).float()
        return is_in_boxes


class GAULossComputation(object):
"""
This class computes the Gaussian Net losses.
"""
def __init__(self, cfg):
object_sizes_of_interest = [
[-1, 64],
[64, 128],
[128, 256],
[256, 512],
[512, INF],
]
self.prepare_targets = TargetGenerator(cfg.MODEL.GAU.LABEL_BETA,
cfg.MODEL.GAU.NUM_CLASSES-1,
object_sizes_of_interest,
cfg.MODEL.GAU.LABEL_RADIUS)
self.cls_loss_func = FixSigmoidFocalLoss(
cfg.MODEL.GAU.LOSS_GAMMA,
cfg.MODEL.GAU.LOSS_ALPHA,
self.prepare_targets.sigma,
cfg.MODEL.GAU.FPN_STRIDES,
cfg.MODEL.GAU.C
)
# we make use of IOU Loss for bounding boxes regression,
# but we found that L1 in log scale can yield a similar performance
# self.box_reg_loss_func = IOULoss()
# self.centerness_loss_func = nn.BCEWithLogitsLoss()
self.vis_labels = cfg.MODEL.GAU.DEBUG.VIS_LABELS
def valid_pos(self, matches):
"""
:param match: list[Tensor], each Tensor shape is (B, H, W, C), list len is len(fpn_levels)
:return:
"""
valids = []
for l, match in enumerate(matches):
center = match[:, 1:-1, 1:-1, :]
top = match[:, :-2, 1:-1, :]
bottom = match[:, 2:, 1:-1, :]
left = match[:, 1:-1, :-2, :]
right = match[:, 1:-1, 2:, :]
valid = (left == right) & (top == bottom) & (left == top) & (left == center) & (center >= 0)
valids.append(valid)
return valids
def reshape(self, labels, cls_logits, gau_logits, matched_gt_idxs):
"""
list of flatten tensor shape (B, M) to list of shape(B, H, W, C)
:param labels:
:param cls_logits:
:param matched_gt_idxs:
:return:
"""
num_classes = cls_logits[0].size(1)
cls_flatten = []
gau_flatten = []
labels_flatten = []
matched_flatten = []
for l in range(len(labels)):
N, C, H, W = cls_logits[l].shape
cls_flatten.append(cls_logits[l].permute(0, 2, 3, 1))
gau_flatten.append(gau_logits[l].permute(0, 2, 3, 1))
labels_flatten.append(labels[l].reshape(N, H, W, num_classes))
matched_flatten.append(matched_gt_idxs[l].reshape(N, H, W, num_classes))
return cls_flatten, gau_flatten, labels_flatten, matched_flatten
def __call__(self, locations, logits, targets):
"""
Arguments:
locations (list[BoxList])
box_cls (list[Tensor])
box_regression (list[Tensor])
centerness (list[Tensor])
targets (list[BoxList])
Returns:
cls_loss (Tensor)
reg_loss (Tensor)
centerness_loss (Tensor)
"""
cls_logits, gau_logits = logits
N = cls_logits[0].size(0)
labels, matched_gt_idxs = self.prepare_targets(locations, targets)
# list of flatten tensor shape (B, M) to list of shape(B, H, W, C)
        cls_logits, gau_logits, labels, matched_gt_idxs = self.reshape(labels, cls_logits, gau_logits, matched_gt_idxs)
        valids_pos = self.valid_pos(matched_gt_idxs)
        if self.vis_labels:
            show_label_map(labels, matched_gt_idxs, valids_pos, cls_logits, gau_logits)
# box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
# labels_flatten = torch.cat(labels_flatten, dim=0)
#
# pos_inds = torch.nonzero(labels_flatten > 0)
# norm = labels_flatten[pos_inds].sum() + N
# neg_loss, pos_loss = self.cls_loss_func(
# box_cls_flatten,
# labels_flatten
# )
# neg_loss /= norm / 200
# pos_loss /= norm # add N to avoid dividing by a zero
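        # Normalization used below: the focal neg/pos terms are divided by the number of positive
        # locations (npos), the iou/l1 terms by the sum of the soft labels (norm); the batch size N
        # is added to both so we never divide by zero when the batch has no positives.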
loss = {}
# norm = sum([label.sum() / (4**i) for i, label in enumerate(labels_flatten)]) + N
        norm = sum([label.sum() for label in labels]) + N
        npos = sum([(label > 0).sum() for label in labels]) + N
        losses_fpn = [0.] * 4  # 6
        norms = [npos, npos] + [norm] * 2  # 4
        if self.vis_labels:
            print(sum([(label > 0).sum() for label in labels]), '+', end='')
            print(sum([label.sum() for label in labels]), '+', end='')
            print(norm)
        for i, (label, box_cls, gau) in enumerate(zip(labels, cls_logits, gau_logits)):
            losses = self.cls_loss_func(box_cls, gau, label, valids_pos[i])
            for j in range(len(losses)):
                losses_fpn[j] += losses[j] / norms[j]
# loss.update({
# "neg_loss{}".format(i): neg_loss / norm,
# "pos_loss{}".format(i): pos_loss / norm
# })
loss.update({
"neg_loss": losses_fpn[0],
"pos_loss": losses_fpn[1],
"iou_loss": losses_fpn[2],
"l1_loss": losses_fpn[3],
# "gau_neg_loss": losses_fpn[2],
# "gau_loss": losses_fpn[3],
# "wh_loss": losses_fpn[4],
# "xy_loss": losses_fpn[5],
})
        return loss  # neg_losses, pos_losses


def make_gau_loss_evaluator(cfg):
loss_evaluator = GAULossComputation(cfg)
return loss_evaluator
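# A typical call site would look like the hedged sketch below; the surrounding GaussianNet head is
# assumed to provide per-level `locations`, `cls_logits`, `gau_logits` and ground-truth `targets`:
#
#     loss_evaluator = make_gau_loss_evaluator(cfg)
#     loss_dict = loss_evaluator(locations, (cls_logits, gau_logits), targets)
#     total_loss = sum(loss for loss in loss_dict.values())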


_vis_iter = 0


def show_label_map(labels, matched_gt_idxs, valids_pos, box_cls, gaus):
import matplotlib.pyplot as plt
import numpy as np
    global _vis_iter
    _vis_iter += 1
    if _vis_iter % 20 != 0:
        return
new_labels = []
new_clses = []
new_match = []
new_valid = []
new_gaus = []
for i, (label, cls) in enumerate(zip(labels, box_cls)):
new_labels.append(label.cpu().detach().numpy())
new_clses.append(cls.cpu().detach().numpy())
new_gaus.append(gaus[i].cpu().detach().numpy())
new_match.append(matched_gt_idxs[i].float().cpu().detach().numpy())
new_valid.append(valids_pos[i].float().cpu().detach().numpy())
labels = new_labels
box_cls = new_clses
match = new_match
valid = new_valid
gaus = new_gaus
sigmoid = lambda x: 1 / (1 + np.exp(-x))
N = sum([(label[0].sum(axis=(0, 1)) > 0).sum() for label in labels])
C = 4
n = 1
# print(N)
plt.figure(figsize=(12, N*4))
for l, label in enumerate(labels):
for c in range(label.shape[-1]):
if label[0, :, :, c].sum() > 0:
plt.subplot(N, C, n)
# print(label.shape)
                plt.imshow(np.log(label[0, :, :, c] + 0.01), vmin=np.log(0.01), vmax=np.log(1.01))  # fix vmin/vmax so all panels share one scale instead of being normalized independently
plt_str = ("pos_count:{}; {}, {}".format((label[0, :, :, c] > 0).sum(), l, c))
plt.title(plt_str)
plt.subplot(N, C, n + 1)
pred = sigmoid(box_cls[l][0, :, :, c])
                plt.imshow(np.log(pred + 0.01), vmin=np.log(0.01), vmax=np.log(1.01))  # fixed vmin/vmax keep absolute values comparable across panels
plt.title("p: [{:.4f}, {:.4f}]".format(pred.min(), pred.max()))
plt.subplot(N, C, n + 2)
pred = sigmoid(gaus[l][0, :, :, c])
                plt.imshow(np.log(pred + 0.01), vmin=np.log(0.01), vmax=np.log(1.01))  # fixed vmin/vmax keep absolute values comparable across panels
plt.title("g: [{:.4f}, {:.4f}]".format(pred.min(), pred.max()))
plt.subplot(N, C, n + 3)
pred = (sigmoid(gaus[l][0, :, :, c]) * sigmoid(box_cls[l][0, :, :, c])) ** 0.5
                plt.imshow(np.log(pred + 0.01), vmin=np.log(0.01), vmax=np.log(1.01))  # fixed vmin/vmax keep absolute values comparable across panels
plt.title("(p*g)^(0.5): [{:.4f}, {:.4f}]".format(pred.min(), pred.max()))
# print(plt_str)
n += C
# plt.show()
    plt.savefig("outputs/pascal/gau/tmp/iter_png/{}.png".format(_vis_iter))
# label_map = 0
# shape, pos_count = None, []
# label_maps = []
# for i in range(0, len(labels)):
# label_sum = labels[i].sum()
# if shape is None:
# if label_sum > 0:
# shape = labels[i].shape
# label = labels[i]
# label_maps.append(label)
# else:
# label = F.upsample(labels[i], shape[2:], mode='bilinear')
# # label_map += label
# label_maps.append(label)
# pos_count.append(int(label_sum.cpu().numpy()))
# # print(label_map.shape)
#
# for i, label_map in enumerate(label_maps):
# label_map = label_map[0].permute((1, 2, 0)).cpu().numpy()[:, :, 0].astype(np.float32)
# max_l = label_map.max()
# if max_l > 0:
# label_map /= max_l
#
# plt.subplot(len(label_maps), 1, i+1)
# plt.imshow(label_map)
# plt.title("pos_count:{} ".format(pos_count))
# plt.show()
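

if __name__ == "__main__":
    # Minimal smoke test for TargetGenerator (an illustrative sketch added here, not part of the
    # original training pipeline). It assumes maskrcnn_benchmark's BoxList is importable and uses a
    # single FPN level with a made-up stride, feature-map size and one ground-truth box.
    from maskrcnn_benchmark.structures.bounding_box import BoxList

    stride, h, w = 8, 16, 16
    # Build an FCOS-style location grid: row-major, x varies fastest, one point per stride.
    shifts_x = torch.arange(0, w * stride, stride, dtype=torch.float32) + stride // 2
    shifts_y = torch.arange(0, h * stride, stride, dtype=torch.float32) + stride // 2
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
    locations = [torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1)]

    target = BoxList(torch.tensor([[20., 24., 70., 90.]]), (w * stride, h * stride), mode="xyxy")
    target.add_field("labels", torch.tensor([1]))  # class ids start from 1

    generator = TargetGenerator(beta=2.0, num_classes=20,
                                object_sizes_of_interest=[[-1, INF]], label_radius=1.0)
    cls_labels, matched_gt_idxs = generator(locations, [target])
    print(cls_labels[0].shape, matched_gt_idxs[0].shape)  # both (h * w, num_classes) for one image
    print(cls_labels[0].max(), (matched_gt_idxs[0] >= 0).sum())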