# def calc_aux_loss()
#
# in models/base.py [0:0]


    def calc_aux_loss(self, p4, labels, args):
        """Compute the auxiliary OOD-aware loss selected by ``args.aux_ood_loss``.

        Supported modes are ``'pascl'``, ``'simclr'`` and ``'cocl'``; any other
        value falls through and a zero loss is returned.

        Args:
            p4: feature batch. Assumed layout along dim 0 is
                [ID view 1 (num_in); ID view 2 (num_in); OOD (rest)] — the
                slicing below depends on this; TODO confirm against the caller.
            labels: in-distribution class labels, shape (num_in,). ``num_in``
                is derived from this tensor, so both ID views must share it.
            args: namespace carrying ``aux_ood_loss`` plus, per mode,
                ``k`` and ``T`` (pascl).

        Returns:
            A tensor loss — shape (1,) for 'pascl' (and the default zero),
            0-d for 'simclr'; 'cocl' currently raises when tail samples exist.
        """
        device = p4.device
        aux_loss = torch.zeros(1, device=device)
        # Number of in-distribution samples per view.
        num_in = labels.shape[0]

        if 'pascl' == args.aux_ood_loss:
            # Lazily build per-class contrastive reweighting signs (+1 for
            # tail classes, -1 for head classes) and cache them as a buffer
            # so they move with the module's device.
            if not hasattr(self, 'cl_loss_weights'):
                _sigmoid_x = torch.linspace(-1, 1, self.num_classes).to(device)
                # Shift the head/tail split point; the -0.001 nudges _d below
                # -1 when k == 1 so every class lands on the "tail" side.
                _d = -2 * args.k + 1 - 0.001 # - 0.001 to make _d<-1 when k=1
                self.register_buffer('cl_loss_weights', torch.sign((_sigmoid_x-_d)))

            # Classes with index >= (1-k)*num_classes count as tail.
            tail_idx = labels >= round((1-args.k)*self.num_classes) # dont use int! since 1-0.9=0.0999!=0.1
            # Project features; first 2*num_in rows are the two ID views,
            # the remainder are OOD samples.
            all_f = self.forward_projection(p4)
            f_id_view1, f_id_view2 = all_f[0:num_in], all_f[num_in:2*num_in]
            f_id_tail_view1 = f_id_view1[tail_idx] # i.e., 6,7,8,9 in cifar10
            f_id_tail_view2 = f_id_view2[tail_idx] # i.e., 6,7,8,9 in cifar10
            labels_tail = labels[tail_idx]
            f_ood = all_f[2*num_in:]
            # Only add the contrastive term when the batch has tail samples.
            if torch.sum(tail_idx) > 0:
                aux_loss += my_cl_loss_fn3(
                    torch.stack((f_id_tail_view1, f_id_tail_view2), dim=1), f_ood, labels_tail, temperature=args.T,
                    reweighting=True, w_list=self.cl_loss_weights
                )
        elif 'simclr' == args.aux_ood_loss:
            # Despite the 'simclr' name, this is a binary ID-vs-OOD classifier
            # on the first projection logit: label 1 for the 2*num_in ID rows,
            # 0 for the OOD remainder.
            ood_logits = self.projection(p4)[:, 0]
            ood_labels = torch.zeros((len(p4),), device=device)
            # Requires at least one OOD sample after the two ID views.
            # NOTE(review): assert is stripped under `python -O` — confirm
            # this is acceptable for a training-time invariant.
            assert len(p4) > num_in*2
            ood_labels[:num_in*2] = 1.
            # Overwrites (not accumulates into) aux_loss; result is 0-d.
            aux_loss = F.binary_cross_entropy_with_logits(ood_logits, ood_labels)
        
        elif 'cocl' == args.aux_ood_loss:
            from utils.loss_fn import compute_dist

            # Hard-coded COCL hyper-parameters: loss weights, softmax
            # temperature, triplet margin, and head/tail class fractions.
            Lambda2, Lambda3 = 0.05, 0.1
            temperature, margin = 0.07, 1.0
            headrate, tailrate = 0.4, 0.4

            # NOTE(review): f_id_view has 2*num_in rows but head_idx/tail_idx
            # below are boolean masks of length num_in (from `labels`) —
            # f_id_view[head_idx] / f_id_view[tail_idx] look like they would
            # raise a mask-shape mismatch in PyTorch. Confirm whether labels
            # here covers both views or only the first num_in rows are meant.
            f_id_view = p4[:2*num_in]
            f_ood = p4[2*num_in:]
            head_idx = labels<= round(headrate*self.num_classes) # dont use int! since 1-0.9=0.0999!=0.1
            tail_idx = labels>= round((1-tailrate)*self.num_classes) # dont use int! since 1-0.9=0.0999!=0.1
            f_id_head_view = f_id_view[head_idx] # i.e., 6,7,8,9 in cifar10
            f_id_tail_view = f_id_view[tail_idx] # i.e., 6,7,8,9 in cifar10
            labels_tail = labels[tail_idx]

            # OOD-aware tail class prototype learning
            if len(f_id_tail_view) > 0 and Lambda2 > 0:
                ## TODO
                # Branch is intentionally unfinished: the two lines after the
                # raise are dead code kept as a sketch of the intended loss.
                raise NotImplementedError
                logits = self.forward_weight(f_id_tail_view, f_ood, temperature=temperature)
                tail_loss = F.cross_entropy(logits, labels_tail-round((1-tailrate)*self.num_classes))
            else:
                tail_loss = torch.zeros((1, ), device=device)

            # debiased head class learning
            if Lambda3 > 0:
                # Anchor = each OOD sample; positive = the OOD sample farthest
                # from it (by compute_dist, argmax over dim 1 — presumably a
                # distance matrix; verify compute_dist's convention).
                dist1 = compute_dist(f_ood, f_ood)
                _, dist_max1 = torch.max(dist1, 1)
                positive = f_ood[dist_max1]

                # Negative = a uniformly random head-class ID sample per OOD
                # sample. NOTE(review): randint raises if f_id_head_view is
                # empty (high=0) — no guard here.
                dist2 = torch.randint(low = 0, high= len(f_id_head_view), size = (1, len(f_ood))).to(device).squeeze()
                negative = f_id_head_view[dist2]
                
                triplet_loss = torch.nn.TripletMarginLoss(margin=margin)
                head_loss = triplet_loss(f_ood, positive, negative)
            else:
                head_loss = torch.zeros((1, ), device=device)
            
            aux_loss = tail_loss * Lambda2 + head_loss * Lambda3
        
        return aux_loss