def my_cl_loss_fn3()

in utils/loss_fn.py


import torch
import torch.nn as nn
import torch.nn.functional as F


def my_cl_loss_fn3(f_id: torch.Tensor, f_ood: torch.Tensor, labels: torch.Tensor,
                   temperature=0.07, ls=False, tau_list=None, reweighting=False, w_list=None):
    '''
    A variant of the supervised contrastive loss:
    push ID samples away from ID samples of other classes;
    push ID samples away from OOD samples, with a per-class push strength set by the prior distribution P(y);
    pull together ID samples of the same class.

    Args:
        f_id: features of ID (tail) samples. Tensor. Shape=(N_id, N_view, d).
        f_ood: features of OE samples. Tensor. Shape=(N_ood, d).
        labels: class labels of the ID samples. Shape=(N_id,).
        temperature: softmax temperature for the contrastive logits.
        ls: Bool. True to apply label smoothing to the CL target labels.
        tau_list: list of floats, len=num_classes. Per-class label smoothing parameter based on the prior p(y).
        reweighting: Bool. True to reweight each sample's loss by its class weight.
        w_list: list of floats, len=num_classes. Per-class loss weights (used when reweighting=True).
    '''

    f_id = f_id.view(f_id.shape[0], f_id.shape[1], -1) # shape=(N_id,2,d), i.e., 2 views

    N_id = f_id.shape[0]
    N_ood = f_ood.shape[0]
    labels = labels.contiguous().view(-1, 1)
    
    N_views = f_id.shape[1] # = 2
    f_id = torch.cat(torch.unbind(f_id, dim=1), dim=0) # shape=(N_id*2,d)

    # compute logits
    anchor_dot_contrast = torch.div(
        torch.matmul(f_id, torch.cat((f_id, f_ood), dim=0).T),
        temperature) # shape=(2N_id,2*N_id+N_ood)
    # for numerical stability
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) # max over the contrast dim. shape=(2*N_id,1)
    logits = anchor_dot_contrast - logits_max.detach() # shape=(2*N_id,2*N_id+N_ood)
    # remove self-contrast cases (diagonal elements). shape=(2*N_id,2*N_id-1+N_ood)
    logits = logits.masked_select(
        ~torch.eye(logits.shape[0], logits.shape[1], dtype=torch.bool, device=logits.device)
    ).view(logits.shape[0], logits.shape[1] - 1)
    
    # labels for CL:
    mask = torch.eq(labels, labels.T).float().to(labels.device) # shape=(N_id,N_id). 1 -> positive pair
    mask = mask.repeat(N_views, N_views) # shape=(2*N_id,2*N_id)
    mask = torch.cat((mask, torch.zeros(mask.shape[0],N_ood).to(mask.device)),dim=1) # all ood samples are negative samples to ID samples. shape=(2*N_id,2*N_id+N_ood)
    # remove self-contrast cases (diagonal elements). shape=(2*N_id,2*N_id-1+N_ood)
    mask = mask.masked_select(
        ~torch.eye(mask.shape[0], mask.shape[1], dtype=torch.bool, device=mask.device)
    ).view(mask.shape[0], mask.shape[1] - 1)
    cl_labels = nn.functional.normalize(mask, dim=1, p=1) # so that each row has sum 1. shape=(2*N_id,2*N_id-1+N_ood)
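    # cl_labels is a soft target distribution over the 2*N_id-1+N_ood contrast candidates:
    # each same-class ID view gets weight 1/(#positives of that anchor), while
    # different-class ID views and all OOD samples get weight 0.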

    # label smoothing: for each class c, move tau_list[c] of the target mass from the
    # positive ID pairs onto the OOD samples.
    if ls:
        for _c, tau in enumerate(tau_list):
            _c_idx = labels == _c # shape=(N_id,1)
            _c_idx = torch.cat([_c_idx, _c_idx], dim=0).squeeze() # cover both views. shape=(2*N_id,)
            cl_labels[_c_idx] *= 1 - tau # only the positive (same-class) entries are nonzero
            # OOD columns start at index 2*N_id-1 after the diagonal removal above:
            cl_labels[_c_idx, 2*N_id-1:] = tau / N_ood

    # loss: cross-entropy between the soft CL targets and the softmax over the contrast logits
    loss = torch.sum(-cl_labels * F.log_softmax(logits, dim=-1), dim=-1) # shape=(2*N_id,)

    # reweighting: scale each sample's loss by its class weight w_list[c]
    if reweighting:
        assert ls is False
        for _c, w in enumerate(w_list):
            _c_idx = labels == _c
            if torch.sum(_c_idx) > 0:
                assert w > 0, ("Non-positive loss weight value detected: %s among %s when c=%s among %s" % (w, w_list, _c, torch.unique(labels)))
                _c_idx = torch.cat([_c_idx,_c_idx], dim=0).squeeze()
                loss[_c_idx] *= w
    
    # mean over the batch (all 2*N_id anchor views):
    loss = loss.mean()

    return loss
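
Example usage (a minimal sketch with random dummy features; the shapes, hyperparameters, and tau_list values below are illustrative assumptions, not values taken from the repo):

import torch
import torch.nn.functional as F

N_id, N_ood, d, num_classes = 8, 16, 128, 3
f_id = F.normalize(torch.randn(N_id, 2, d), dim=-1)   # two L2-normalized views per ID sample
f_ood = F.normalize(torch.randn(N_ood, d), dim=-1)    # one view per OE sample
labels = torch.randint(0, num_classes, (N_id,))

# plain contrastive loss:
loss = my_cl_loss_fn3(f_id, f_ood, labels, temperature=0.07)

# with per-class label smoothing (one tau per class):
loss_ls = my_cl_loss_fn3(f_id, f_ood, labels, ls=True, tau_list=[0.01, 0.05, 0.1])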