aiops/RCRank/model/loss/loss.py (50 lines of code) (raw):

import torch import torch.nn as nn class MarginLoss(nn.Module): def __init__(self, margin=0.03): super(MarginLoss, self).__init__() self.margin = margin def forward(self, pred, label): batch, label_num = label.shape label_sort, index_sort = torch.sort(label, dim=-1, descending=True) pred_sorted_by_true = pred.gather(dim=1, index=index_sort) pred_dis = pred_sorted_by_true.unsqueeze(2) - pred_sorted_by_true.unsqueeze(1) label_dis = label_sort.unsqueeze(2) - label_sort.unsqueeze(1) mask = torch.triu(torch.ones(label_num, label_num), diagonal=2) + torch.tril(torch.ones(label_num, label_num), diagonal=0) mask = mask.to(torch.bool).to(label.device) dis_dis = self.margin + label_dis - pred_dis dis_dis_mask = dis_dis.masked_fill(mask, 0) loss = torch.relu(dis_dis_mask) return loss.mean() class ListnetLoss(nn.Module): def __init__(self): super(ListnetLoss, self).__init__() def forward(self, pred, label): top1_target = torch.softmax(label, dim=-1) top1_predict = torch.softmax(pred, dim=-1) return torch.mean(-torch.sum(top1_target * torch.log(top1_predict))) class ListMleLoss(nn.Module): def __init__(self): super(ListMleLoss, self).__init__() def forward(self, y_pred, y_true, k=None): if k is not None: sublist_indices = (y_pred.shape[1] * torch.rand(size=k)).long() y_pred = y_pred[:, sublist_indices] y_true = y_true[:, sublist_indices] _, indices = y_true.sort(descending=True, dim=-1) pred_sorted_by_true = y_pred.gather(dim=1, index=indices) cumsums = pred_sorted_by_true.exp().flip(dims=[1]).cumsum(dim=1).flip(dims=[1]) listmle_loss = torch.log(cumsums + 1e-10) - pred_sorted_by_true return listmle_loss.sum(dim=1).mean() class ThresholdLoss(nn.Module): def __init__(self, threshold=0.05, margin_left=0.03, margin_right=0.03): super(ThresholdLoss, self).__init__() self.threshold = threshold self.margin_left = margin_left self.margin_right = margin_right def forward(self, pred, label): sign = ((label - self.threshold) + 1e-6) / torch.abs((label - self.threshold) + 1e-6) sign = sign.detach() ts_loss = (0.5 - 0.5 * sign) * (pred - self.threshold + self.margin_left) + (0.5 + 0.5 * sign) * (self.threshold - pred + self.margin_right) loss = torch.relu(ts_loss) return loss.mean()