def calc_loss()

in models/base.py [0:0]


    def calc_loss(self, logits, p4, labels, adjustments, args, use_imood=True):
        """Compute the in-distribution, OOD and auxiliary OOD training losses.

        The batch is expected to hold the in-distribution (ID) samples first,
        followed by the OOD samples; ``labels`` is concatenated with itself to
        label the ID part (presumably two augmented views per ID sample --
        TODO confirm against the data loader).

        Args:
            logits: Model outputs for the whole batch; ``logits.shape[0]`` is
                the total number of samples. Split by ``self.parse_logits``.
            p4: Forwarded to ``self.parse_logits``, ``self.forward_lambda``
                and ``self.calc_aux_loss`` (presumably an intermediate
                feature map -- verify against the model definition).
            labels: Integer class labels of the ID samples (one copy).
            adjustments: Added to the ID classification logits before the
                cross-entropy; in the ``ada_*`` branch its softmax over
                ``dim=1`` is used as the class prior.
            args: Namespace providing ``ood_metric``; also forwarded to
                ``self.calc_aux_loss``.
            use_imood: Not read by this method; kept for interface
                compatibility.

        Returns:
            Tuple ``(in_loss, ood_loss, aux_ood_loss)`` of tensors with
            shape ``(1,)``.

        Raises:
            AssertionError: If the batch holds no samples beyond the ID ones.
            NotImplementedError: If ``args.ood_metric`` is neither an
                ``ada_*`` metric nor one of the supported plain metrics.
        """
        in_labels = torch.cat([labels, labels], dim=0)
        num_sample, total_num_in = logits.shape[0], in_labels.shape[0]
        assert num_sample > total_num_in, \
            'batch must contain OOD samples after the in-distribution ones'
        device = in_labels.device
        metric = args.ood_metric

        in_sample_in_logits, in_sample_ood_logits, ood_sample_in_logits, ood_sample_ood_logits \
            = self.parse_logits(logits, p4, metric, total_num_in)

        in_loss, ood_loss, aux_ood_loss = \
            torch.zeros((1,), device=device), torch.zeros((1,), device=device), torch.zeros((1,), device=device)

        if not metric.startswith('ada_'):

            in_loss += F.cross_entropy(in_sample_in_logits + adjustments, in_labels)

            if metric == 'oe':
                # Cross-entropy between the OOD prediction and a uniform target.
                ood_loss += -(ood_sample_ood_logits.mean(1) - torch.logsumexp(ood_sample_ood_logits, dim=1)).mean()
            elif metric == 'energy':
                Ec_out = -torch.logsumexp(ood_sample_ood_logits, dim=1)
                Ec_in = -torch.logsumexp(in_sample_ood_logits, dim=1)
                m_in, m_out = -23 if self.num_classes == 10 else -27, -5  # cifar10/100
                # 0.2 * 0.5 = 0.1, the default loss scale in official Energy OOD
                ood_loss += (torch.pow(F.relu(Ec_in-m_in), 2).mean() + torch.pow(F.relu(m_out-Ec_out), 2).mean()) * 0.2
            elif metric == 'bkg_c':
                # Every OOD sample is assigned the extra class index
                # ``num_classes``. The target needs one entry per OOD sample:
                # F.cross_entropy does not broadcast a length-1 target.
                ood_labels = in_labels.new_full(
                    (ood_sample_ood_logits.shape[0],), self.num_classes)
                ood_loss += F.cross_entropy(ood_sample_ood_logits, ood_labels)
            elif metric == 'bin_disc':
                # Binary discriminator: target 1 for ID samples, 0 for OOD.
                ood_labels = torch.zeros((num_sample,), device=device)
                ood_labels[:total_num_in] = 1.
                ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0).squeeze(1)
                ood_loss += F.binary_cross_entropy_with_logits(ood_logits, ood_labels)
            elif metric == 'mc_disc':
                # Multi-class discriminator: class-index targets must be an
                # integer (long) tensor for F.cross_entropy. ID samples get
                # class 1 and OOD samples class 0, as in 'bin_disc'.
                # NOTE(review): the previous comment read "id: cls0; ood: cls1",
                # which contradicts the assignment below -- confirm against the
                # inference code which convention is intended.
                ood_labels = torch.zeros((num_sample,), dtype=torch.long, device=device)
                ood_labels[:total_num_in] = 1
                ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0)
                ood_loss += F.cross_entropy(ood_logits, ood_labels)
            else:
                raise NotImplementedError(metric)

        else:
            ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0)
            ood_logits = self.parse_ada_ood_logits(ood_logits, metric)

            # Binary targets: 1 for ID samples, 0 for OOD samples.
            ood_labels = torch.zeros((num_sample,), device=device)
            ood_labels[:total_num_in] = 1.

            # presumably ``adjustments`` holds log class priors -- TODO confirm
            cls_prior = F.softmax(adjustments, dim=1)

            # Keep lambda strictly positive so that beta.log() stays finite.
            min_thresh = 1e-4
            lambd = self.forward_lambda(p4).squeeze().clamp(min=min_thresh)

            # Class posterior: label-smoothed one-hot for ID samples, detached
            # classifier softmax for OOD samples.
            smoothing = 0.2
            m_in_labels: torch.Tensor = F.one_hot(in_labels, num_classes=self.num_classes)
            in_posterior = m_in_labels * (1 - smoothing) + smoothing / self.num_classes
            ood_posterior = F.softmax(ood_sample_in_logits.detach(), dim=1)
            cls_posterior = torch.cat((in_posterior, ood_posterior))
            beta = (lambd * cls_posterior / cls_prior).mean(dim=1) #.clamp(min=1e-1, max=1e+1)
            # ood_logits is detached here, so this term only trains beta
            # (i.e. the lambda branch).
            ood_loss += (beta.log() + ood_logits.detach().sigmoid().log()).relu().mean()

            # From here on beta is a constant: it only shifts the OOD logits.
            beta = beta.detach()
            delta = (beta + (beta - 1.) * torch.exp(ood_logits.detach())).clamp(min=1e-1, max=1e+1)
            # ID samples use delta >= 1 and OOD samples delta <= 1.
            delta = torch.cat((delta[:total_num_in].clamp(min=1.),
                                delta[total_num_in:].clamp(max=1.)), dim=0)
            ood_logits = ood_logits - delta.log()

            ood_loss += F.binary_cross_entropy_with_logits(ood_logits, ood_labels)

            if metric == 'ada_oe':  # add original OE loss
                ood_loss += -(ood_sample_ood_logits.mean(1) - torch.logsumexp(ood_sample_ood_logits, dim=1)).mean()

            in_sample_in_logits = in_sample_in_logits + adjustments
            in_loss += F.cross_entropy(in_sample_in_logits, in_labels)

        aux_ood_loss += self.calc_aux_loss(p4, labels, args)

        return in_loss, ood_loss, aux_ood_loss