evaluation/tiny_benchmark/maskrcnn_benchmark/layers/sigmoid_focal_loss.py (163 lines of code) (raw):

import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

try:
    from maskrcnn_benchmark import _C
except ImportError:
    # Only the CUDA focal-loss kernels need the compiled extension. Defer the
    # failure to the first CUDA call so the pure-PyTorch losses in this module
    # remain importable without it.
    _C = None


def _require_compiled_ops():
    """Return the compiled ``_C`` extension or raise a descriptive ImportError."""
    if _C is None:
        raise ImportError(
            "maskrcnn_benchmark._C could not be imported; the CUDA sigmoid "
            "focal loss requires the compiled extension."
        )
    return _C


def _leading_scalar(value):
    """Return ``value[0]`` when ``value`` is indexable, otherwise ``value``.

    Lets the CPU focal loss accept both plain numbers (what
    ``SigmoidFocalLoss`` stores) and one-element sequences/tensors.
    """
    try:
        return value[0]
    except (TypeError, IndexError):
        # Plain Python numbers raise TypeError, 0-dim tensors raise IndexError.
        return value


# TODO: Use JIT to replace CUDA implementation in the future.
class _SigmoidFocalLoss(Function):
    """Autograd wrapper around the compiled sigmoid focal loss kernels."""

    @staticmethod
    def forward(ctx, logits, targets, gamma, alpha):
        """Run ``_C.sigmoid_focalloss_forward`` and stash what backward needs."""
        ctx.save_for_backward(logits, targets)
        num_classes = logits.shape[1]
        ctx.num_classes = num_classes
        ctx.gamma = gamma
        ctx.alpha = alpha

        return _require_compiled_ops().sigmoid_focalloss_forward(
            logits, targets, num_classes, gamma, alpha
        )

    @staticmethod
    @once_differentiable
    def backward(ctx, d_loss):
        """Run ``_C.sigmoid_focalloss_backward``; only ``logits`` gets a gradient."""
        logits, targets = ctx.saved_tensors
        d_loss = d_loss.contiguous()
        d_logits = _require_compiled_ops().sigmoid_focalloss_backward(
            logits, targets, d_loss, ctx.num_classes, ctx.gamma, ctx.alpha
        )
        # One slot per forward input: logits, targets, gamma, alpha.
        return d_logits, None, None, None


sigmoid_focal_loss_cuda = _SigmoidFocalLoss.apply


def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
    """Pure-PyTorch sigmoid focal loss (element-wise, not reduced).

    Args:
        logits: tensor whose dimension 1 indexes the classes.
        targets: integer labels; they are unsqueezed to a column and compared
            against class ids ``1..num_classes``. A label that matches no
            class id (e.g. 0) makes every class a negative; labels ``< 0``
            contribute no loss at all.
        gamma: focusing exponent, either a number or an indexable whose first
            element is used.
        alpha: positive/negative balance weight, same conventions as ``gamma``.

    Returns:
        Tensor with the same shape as ``logits`` holding the per-element loss.
    """
    num_classes = logits.shape[1]
    gamma = _leading_scalar(gamma)
    alpha = _leading_scalar(alpha)

    class_range = torch.arange(
        1, num_classes + 1, dtype=targets.dtype, device=targets.device
    ).unsqueeze(0)
    t = targets.unsqueeze(1)

    p = torch.sigmoid(logits)
    term1 = (1 - p) ** gamma * torch.log(p)
    term2 = p ** gamma * torch.log(1 - p)

    is_pos = (t == class_range).float()
    is_neg = ((t != class_range) & (t >= 0)).float()
    return -is_pos * term1 * alpha - is_neg * term2 * (1 - alpha)


class SigmoidFocalLoss(nn.Module):
    """Summed sigmoid focal loss; dispatches to the CUDA or CPU implementation."""

    def __init__(self, gamma, alpha):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        """Return the focal loss summed over all elements of ``logits``."""
        if logits.is_cuda:
            loss_func = sigmoid_focal_loss_cuda
        else:
            loss_func = sigmoid_focal_loss_cpu

        loss = loss_func(logits, targets, self.gamma, self.alpha)
        return loss.sum()

    def __repr__(self):
        return "{}(gamma={}, alpha={})".format(
            self.__class__.__name__, self.gamma, self.alpha
        )


try:
    # Not referenced in this module; the name is retained at module level in
    # case other code imports it from here.
    from maskrcnn_benchmark.modeling.rpn.gaussian_net.gau_label_infer import three_points_solve
except ImportError:
    three_points_solve = None


class FixedIOULoss(nn.Module):
    """Helpers that fit a 1-D quadratic through three samples and compare boxes.

    NOTE: ``forward`` currently multiplies its result by zero, i.e. the IoU
    loss is computed but disabled.
    """

    def three_point_solve(self, li, lj, lk, a, b, eps=1e-6):
        """Fit ``L(x) = k * (x - c)**2 + const`` through three samples.

        ``li``, ``lj`` and ``lk`` are the values of ``L`` at ``xj - a``,
        ``xj`` and ``xj + b``.

        Args:
            li, lj, lk: sampled values (tensors of equal shape).
            a, b: distances from the middle sample to the left/right samples.
            eps: added to ``k`` before inversion to avoid dividing by zero on
                a flat profile.

        Returns:
            ``(w2, dx)`` where ``w2 = 1 / (k + eps)`` and ``dx`` is the offset
            of the fitted minimum ``c`` from ``xj`` (``c - xj`` up to ``eps``).
        """
        lkj, lji = lk - lj, lj - li
        # (lkj / b - lji / a) == k * (a + b) for a quadratic profile.
        inverse_w2 = (lkj / b - lji / a) / (a + b)
        # The original code used ``w2`` without ever defining it (NameError).
        w2 = 1 / (inverse_w2 + eps)
        # Equivalent closed form without w2:
        # dx = (lkj * a * a + lji * b * b) / (lji * b - lkj * a) / 2
        dx = -(w2 * lji / a + a) / 2
        return w2, dx

    def cross_points_set_solve_3d(self, L, points, a, b, step=1, solver=1):
        """Solve the quadratic fit along x and y for every point in ``points``.

        For each point ``(cj, yj, xj)`` the following cross of samples is read
        from ``L``::

                              L[cj, yj-a, xj]
            L[cj, yj, xj-a]   L[cj, yj, xj]    L[cj, yj, xj+b]
                              L[cj, yj+b, xj]

        Args:
            L: tensor indexed as ``L[c, y, x]``.
            points: integer tensor of shape (N, 3) with columns (c, y, x).
                The caller must ensure ``xj - a``, ``xj + b``, ``yj - a`` and
                ``yj + b`` are valid indices.
            a, b: sample offsets before/after the centre.
            step, solver: unused; kept for interface compatibility.

        Returns:
            ``(dx, dy, w, h)``: the offsets and the ``w2`` values returned by
            ``three_point_solve`` for the x and y directions, each of length N.
        """
        cj, yj, xj = points[:, 0], points[:, 1], points[:, 2]
        idx = torch.arange(len(points), device=points.device)

        lx = L[cj, yj]  # (N, W)
        lxi, lxj, lxk = lx[idx, xj - a], lx[idx, xj], lx[idx, xj + b]
        ly = L[cj, :, xj]  # (N, H) not (H, N)
        # The centre sample is shared by both directions, so lyj == lxj.
        lyi, lyj, lyk = ly[idx, yj - a], lxj, ly[idx, yj + b]

        # Stack x and y problems so a single solve handles both.
        li = torch.cat([lxi, lyi], dim=0)
        lj = torch.cat([lxj, lyj], dim=0)
        lk = torch.cat([lxk, lyk], dim=0)
        s, d = self.three_point_solve(li, lj, lk, a, b)

        n = len(s) // 2
        w, h = s[:n], s[n:]
        dx, dy = d[:n], d[n:]
        return dx, dy, w, h

    def forward(self, bbox, target, sf=0.125):
        """Compute a log-IoU loss between boxes, then zero it out.

        Args:
            bbox: predicted ``(dx, dy, w, h)`` tensors.
            target: target ``(dx, dy, w, h)`` tensors.
            sf: unused; kept for interface compatibility.

        Returns:
            Two tensors shaped like the IoU loss, both multiplied by zero
            (the loss is currently disabled).
        """
        def center2corner(dx, dy, w, h):
            # Distances from the reference point to the four box sides.
            l = w / 2 - dx
            r = w / 2 + dx
            t = h / 2 - dy
            b = h / 2 + dy
            return l, t, r, b

        pred_l, pred_t, pred_r, pred_b = center2corner(*bbox)
        targ_l, targ_t, targ_r, targ_b = center2corner(*target)

        # Keep predicted side distances in a bounded range.
        l_range = (0, 4)
        pred_l = pred_l.clamp(*l_range)
        pred_r = pred_r.clamp(*l_range)
        pred_t = pred_t.clamp(*l_range)
        pred_b = pred_b.clamp(*l_range)

        target_area = target[2] * target[3]
        pred_area = (pred_l + pred_r) * (pred_t + pred_b)

        w_intersect = torch.min(pred_l, targ_l) + torch.min(pred_r, targ_r)
        h_intersect = torch.min(pred_b, targ_b) + torch.min(pred_t, targ_t)
        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect

        # +1 smoothing on both areas; the ratio is floored at 0.1 so the loss
        # is bounded by -log(0.1).
        iou_losses = -torch.log(
            ((area_intersect.clamp(0) + 1.0) / (area_union.clamp(0) + 1.0)).clamp(0.1)
        )
        # An earlier variant also produced a smooth-L1 term gated by a
        # validity mask; both outputs are currently disabled.
        return iou_losses * 0, iou_losses * 0


def smooth_l1(error, beta=1. / 9):
    """
    very similar to the smooth_l1_loss from pytorch, but with
    the extra beta parameter

    Quadratic (``0.5 * e**2 / beta``) where ``|e| < beta``, linear
    (``|e| - 0.5 * beta``) elsewhere. Returned element-wise, not reduced.
    """
    n = torch.abs(error)
    cond = n < beta
    loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    return loss


class FixSigmoidFocalLoss(nn.Module):
    """Focal classification loss on soft targets plus a (disabled) IoU term."""

    def __init__(self, gamma, alpha, sigma, fpn_strides, c, EPS=1e-6):
        super(FixSigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.sigma = sigma
        self.EPS = EPS
        self.fpn_strides = fpn_strides
        self.c = c  # (0.5, 2, 1, 2)
        print("c1, c2, c3, c4 for pos loss:", self.c)
        self.g_mul_p = False
        self.iou_loss = FixedIOULoss()

    def forward(self, cls_logits, gau_logits, targets, valid=None):
        """
        :param cls_logits: shape=(B, H, W, C)
        :param gau_logits: shape=(B, H, W, C)
        :param targets: shape=(B, H, W, C), soft targets; entries <= EPS are
            treated as negatives, entries > EPS as positives
        :param valid: mask of shape (B, H-2, W-2, C) selecting interior
            locations at which boxes are solved; required despite the
            ``None`` default
        :return: (neg_loss_sum, pos_loss_sum, iou_loss * 0, l1_loss * 0)
        :raises ValueError: if ``valid`` is None
        """
        if valid is None:
            raise ValueError("FixSigmoidFocalLoss.forward requires a `valid` mask")

        gamma = self.gamma
        alpha = self.alpha
        eps = self.EPS
        # NOTE: self.c and self.g_mul_p are not used by the active loss below.
        # Earlier experimental loss variants that lived here as commented-out
        # code have been removed.

        q = targets
        p = torch.sigmoid(cls_logits)
        g = torch.sigmoid(gau_logits)

        # loss4, IOU
        neg_loss = (q <= eps).float() * (1 - alpha) * (- p ** gamma * torch.log(1 - p))
        pos_loss = (q > eps).float() * alpha * (- (1 - p) ** gamma * torch.log(p))

        # Switch to channel-first so each batch item is indexed as [c, y, x].
        g = g.permute((0, 3, 1, 2))
        q = q.permute((0, 3, 1, 2))
        valid = valid.permute((0, 3, 1, 2))

        factor = self.sigma * self.sigma
        log_p = -torch.log(g + eps) * factor
        log_q = -torch.log(q + eps) * factor

        fpn_stride, object_range = self.fpn_strides[0], torch.Tensor([32, 64])
        sf = 1 / ((object_range[0] / fpn_stride * object_range[1] / fpn_stride) ** 0.5)

        iou_losses = 0.
        l1_losses = 0.
        for b, valid_b in enumerate(valid):
            idx = torch.nonzero(valid_b)
            if len(idx) == 0:
                continue
            # valid is defined on the interior (border of 1 removed); shift
            # (y, x) back to full-map coordinates.
            idx[:, 1:] += 1
            p_bboxes = self.iou_loss.cross_points_set_solve_3d(log_p[b], idx, 1, 1, step=1, solver=1)
            q_bboxes = self.iou_loss.cross_points_set_solve_3d(log_q[b], idx, 1, 1, step=1, solver=1)
            iou_loss, l1_loss = self.iou_loss(p_bboxes, q_bboxes, sf)
            # Weight each location by its soft target value.
            valid_q = q[b, :, 1:-1, 1:-1][valid_b]
            iou_losses += (valid_q * iou_loss).sum()
            l1_losses += (valid_q * l1_loss).sum()

        def ri(x): return round(x.item(), 3)
        print("neg_loss", ri(neg_loss.max()), ri(neg_loss.mean()), end=';')
        print(iou_losses, l1_losses)
        return neg_loss.sum(), pos_loss.sum(), iou_losses * 0, l1_losses * 0


class L2LossWithLogit(nn.Module):
    """Sum-reduced MSE between ``sigmoid(logits)`` and ``targets``."""

    def __init__(self):
        super(L2LossWithLogit, self).__init__()
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        return self.mse(p, targets)