utils/loss_fn.py:

'''
The my_cl_loss_fn3 function is adapted from
https://github.com/HobbitLong/SupContrast/blob/master/losses.py,
which is originally licensed under BSD-2-Clause.
'''
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from timm.loss import SoftTargetCrossEntropy


def prior_to_tau(prior, tau0=0.1):
    '''
    Map a class prior p(y) to per-class label-smoothing parameters, scaled
    linearly from tau0 (head class) down to 0 (tail class).

    Args:
        prior: iterable with len=num_classes (numpy array or torch.Tensor),
            assumed sorted from head to tail class.
    Returns:
        tau: iterable with len=num_classes.
    '''
    tau = tau0 / (prior[0] - prior[-1]) * (prior - prior[-1])
    return tau


def my_cl_loss_fn3(f_id: torch.Tensor, f_ood: torch.Tensor, labels: torch.Tensor,
                   temperature=0.07, ls=False, tau_list=None,
                   reweighting=False, w_list=None):
    '''
    A variant of supervised contrastive loss:
        - push ID samples away from ID samples of different classes;
        - push ID samples away from OOD samples, with a push strength set
          according to the prior distribution P(y);
        - pull together ID samples of the same class.

    Args:
        f_id: features of ID_tail samples. Tensor. Shape=(N_id, N_view, d).
        f_ood: features of OE samples. Tensor. Shape=(N_ood, d).
        labels: labels of ID_tail samples.
        ls: Bool. True to apply label smoothing to the CL loss labels.
        tau_list: list of floats, len=num_classes. Label-smoothing parameter
            for each class, based on the prior p(y).
        reweighting: Bool. True to reweight the per-sample loss by class.
        w_list: list of floats, len=num_classes. Per-class loss weights
            (required when reweighting=True).
    '''
    f_id = f_id.view(f_id.shape[0], f_id.shape[1], -1)  # shape=(N_id,2,d), i.e., 2 views
    N_id = f_id.shape[0]
    N_ood = f_ood.shape[0]
    labels = labels.contiguous().view(-1, 1)
    N_views = f_id.shape[1]  # = 2
    f_id = torch.cat(torch.unbind(f_id, dim=1), dim=0)  # shape=(N_id*2,d)

    # compute logits
    anchor_dot_contrast = torch.div(
        torch.matmul(f_id, torch.cat((f_id, f_ood), dim=0).T),
        temperature)  # shape=(2*N_id,2*N_id+N_ood)
    # for numerical stability
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()  # shape=(2*N_id,2*N_id+N_ood)
    # remove self-contrast cases (diagonal elements)
    logits = logits.masked_select(
        ~torch.eye(logits.shape[0], logits.shape[1], dtype=bool).to(logits.device)
    ).view(logits.shape[0], logits.shape[1] - 1)

    # labels for CL:
    mask = torch.eq(labels, labels.T).float().to(labels.device)  # shape=(N_id,N_id). 1 -> positive pair
    mask = mask.repeat(N_views, N_views)  # shape=(2*N_id,2*N_id)
    # all OOD samples are negatives for ID samples. shape=(2*N_id,2*N_id+N_ood)
    mask = torch.cat((mask, torch.zeros(mask.shape[0], N_ood).to(mask.device)), dim=1)
    # remove self-contrast cases (diagonal elements). shape=(2*N_id,2*N_id-1+N_ood)
    mask = mask.masked_select(
        ~torch.eye(mask.shape[0], mask.shape[1], dtype=bool).to(mask.device)
    ).view(mask.shape[0], mask.shape[1] - 1)
    cl_labels = nn.functional.normalize(mask, dim=1, p=1)  # each row sums to 1. shape=(2*N_id,2*N_id-1+N_ood)

    # label smoothing:
    if ls:
        for _c, tau in enumerate(tau_list):
            _c_idx = labels == _c
            _c_idx = torch.cat([_c_idx, _c_idx], dim=0).squeeze()
            cl_labels[_c_idx] *= 1 - tau
            # After the diagonal removal above, the OOD block starts at column
            # 2*N_id-1; spread tau uniformly over all N_ood OOD columns so that
            # each row still sums to 1.
            cl_labels[_c_idx, 2 * N_id - 1:] = tau / N_ood

    # loss
    loss = torch.sum(-cl_labels * F.log_softmax(logits, dim=-1), dim=-1)

    # reweighting:
    if reweighting:
        assert ls is False
        for _c, w in enumerate(w_list):
            _c_idx = labels == _c
            if torch.sum(_c_idx) > 0:
                assert w > 0, ("Negative loss weight value detected: %s among %s when c=%s among %s"
                               % (w, w_list, _c, torch.unique(labels)))
                _c_idx = torch.cat([_c_idx, _c_idx], dim=0).squeeze()
                loss[_c_idx] *= w

    # mean over the batch:
    loss = loss.mean()  # average over all rows

    return loss
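

# --- Illustrative usage sketch ---------------------------------------------
# A minimal, hypothetical example of how prior_to_tau and my_cl_loss_fn3
# could be wired together. The function name `_demo_my_cl_loss_fn3` and the
# batch sizes, feature dimension, and class prior below are made up for
# illustration only; they are not part of any training pipeline.
def _demo_my_cl_loss_fn3():
    num_classes, d = 4, 128
    n_id, n_ood = 16, 32
    # assumed class prior p(y), sorted from head to tail class
    prior = torch.tensor([0.5, 0.3, 0.15, 0.05])
    tau_list = prior_to_tau(prior, tau0=0.1)  # per-class smoothing strengths
    # two L2-normalized views per ID sample, plus OE/OOD features
    f_id = F.normalize(torch.randn(n_id, 2, d), dim=-1)
    f_ood = F.normalize(torch.randn(n_ood, d), dim=-1)
    labels = torch.randint(0, num_classes, (n_id,))
    loss = my_cl_loss_fn3(f_id, f_ood, labels,
                          temperature=0.07, ls=True, tau_list=tau_list)
    return loss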


def oe_loss_fn(logits: torch.Tensor):
    '''
    Outlier Exposure (OE) loss: cross-entropy between the softmax of `logits`
    and the uniform distribution. The original, possibly numerically unstable
    implementation (it relies on torch.logsumexp).
    '''
    return -(logits.mean(1) - torch.logsumexp(logits, dim=1)).mean()


def stable_imbce(logits: torch.Tensor, labels: torch.Tensor, beta: torch.Tensor, eps=1e-4):
    '''
    Numerically stable binary cross-entropy on the scaled probability
    sigmoid(logits) / beta, i.e. a stable version of
    F.binary_cross_entropy(logits.sigmoid() / beta, labels),
    rewritten with the max trick max_val = max(-logits, 0);
    the clamps by `eps` guard the logarithms.
    '''
    # l0 = ((1. - labels) * logits + (1 + (-logits).exp()).log()).mean()
    # l1 = (beta.log() + (1 + (-logits).exp()).log() - (1 - labels) * (beta * (1 + (-logits).exp()) - 1.).log()).mean()
    max_val = (-logits).clamp(min=0.)
    loss = 0.
    loss += beta.clamp(min=eps).log()
    # loss += - labels * beta.clamp(min=eps).log()
    loss += labels * max_val + ((-max_val).exp() + (-logits - max_val).exp()).log()
    loss += (labels - 1.) * ((beta - 1.) * (-max_val).exp() + beta * (-logits - max_val).exp()).clamp(min=eps).log()
    return loss.mean()


def normalize(nparray, order=2, axis=0):
    """Normalize an N-D numpy array along the specified axis."""
    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
    return nparray / (norm + np.finfo(np.float32).eps)


def compute_dist(array1, array2, type='euclidean'):
    """Compute the euclidean or cosine distance of all pairs.

    Args:
        array1: array with shape [m1, n]
        array2: array with shape [m2, n]
        type: one of ['cosine', 'euclidean']. The 'cosine' branch expects
            numpy arrays; the 'euclidean' branch operates on torch tensors.
    Returns:
        array with shape [m1, m2]
    """
    assert type in ['cosine', 'euclidean']
    if type == 'cosine':
        array1 = normalize(array1, axis=1)
        array2 = normalize(array2, axis=1)
        dist = np.matmul(array1, array2.T)
        return dist
    else:
        # # numpy version:
        # # shape [m1, 1]
        # square1 = np.sum(np.square(array1), axis=1)[..., np.newaxis]
        # # shape [1, m2]
        # square2 = np.sum(np.square(array2), axis=1)[np.newaxis, ...]
        # squared_dist = - 2 * np.matmul(array1, array2.T) + square1 + square2
        # squared_dist[squared_dist < 0] = 0
        # dist = np.sqrt(squared_dist)

        # shape [m1, 1]
        square1 = torch.unsqueeze(torch.sum(torch.square(array1), axis=1), 1)
        # shape [1, m2]
        square2 = torch.unsqueeze(torch.sum(torch.square(array2), axis=1), 0)
        squared_dist = -2 * torch.matmul(array1, array2.T) + square1 + square2
        squared_dist[squared_dist < 0] = 0
        dist = torch.sqrt(squared_dist)
        return dist


if __name__ == '__main__':
    # Sanity checks comparing stable and unstable binary cross-entropy variants.
    labels = torch.cat((torch.ones((10, 1)), torch.zeros((10, 1)))).cuda()
    logits = torch.rand((20, 1)).cuda() * 5.

    # l1: reference BCE-with-logits; l2: the same loss via the max trick;
    # l3: the naive (overflow-prone) formula.
    l1 = F.binary_cross_entropy_with_logits(logits, labels)
    max_val = (-logits).clamp(0)
    l2 = (1. - labels) * logits + max_val + (torch.exp(-max_val) + torch.exp(-logits - max_val)).log()
    l2 = l2.mean()
    print(l1, l2)
    l3 = - logits * labels + torch.log(1 + torch.exp(logits))
    l3 = l3.mean()
    print(l1, l2, l3)

    # l4: BCE on sigmoid(logits)/beta; l5: its stable counterpart.
    beta = torch.rand_like(logits).clamp(min=1e-4) * 1. + 1.
    l4 = F.binary_cross_entropy(logits.sigmoid() / beta, labels)
    l5 = stable_imbce(logits, labels, beta)
    print(l4, l5)

    # l6 and l7: two equivalent reformulations of the same beta-scaled BCE.
    l6 = l1 + (beta.log() - (1. - labels) * ((beta - 1.) * logits.exp() + beta).log()).mean()
    print(l6)
    delta = (beta + (beta - 1.) * torch.exp(logits))  # .clamp(min=1e-1, max=1e+1)
    l7 = F.binary_cross_entropy_with_logits(logits - delta.log(), labels)
    print(l7)
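

# --- Illustrative usage sketch ---------------------------------------------
# A minimal, hypothetical example of compute_dist. The function name
# `_demo_compute_dist` and the array sizes are made up for illustration.
# Note that, as written above, the 'cosine' branch expects numpy arrays
# while the 'euclidean' branch operates on torch tensors.
def _demo_compute_dist():
    feats_np = np.random.rand(5, 16).astype(np.float32)
    cos_sim = compute_dist(feats_np, feats_np, type='cosine')    # (5, 5) cosine similarities
    feats_t = torch.rand(5, 16)
    euc_dist = compute_dist(feats_t, feats_t, type='euclidean')  # (5, 5) pairwise L2 distances
    return cos_sim, euc_dist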