def parse_logits()

in models/base.py [0:0]


    def parse_logits(self, all_logits, all_features, metric, num_in) \
            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split model outputs into in-sample / OOD-sample score blocks.

        The first ``num_in`` rows are handled as in-distribution samples and
        the remaining rows as OOD samples. Which columns (or which auxiliary
        tensor) form the "in" part and the "ood" part depends on ``metric``.
        Branches are tried in this order and the first match wins:

        * ``bin_disc``: last logit column is the ood part, the rest is the
          in part.
        * ``mc_disc``: last two logit columns are the ood part, the rest is
          the in part.
        * ``msp`` / ``oe`` / ``bkg_c`` / ``energy``: the full logits are used
          for both parts.
        * ``gradnorm``: logits are the in part, ``all_features`` the ood part.
        * ``maha``: logits are the in part, ``self.calc_maha_score`` applied
          to ``all_features`` (with a new axis inserted at position 1) is the
          ood part.

        Args:
            all_logits: Logits for all samples, in-samples first.
                Assumed to be at least 2-D (samples x classes) -- TODO confirm.
            all_features: Features for all samples, same row order as
                ``all_logits``; only read for ``gradnorm`` and ``maha``.
            metric: Metric identifier; matched with the ``in`` operator, so a
                string is matched by substring.
            num_in: Number of leading rows that are in-distribution samples.

        Returns:
            ``(in_sample_in, in_sample_ood, ood_sample_in, ood_sample_ood)``.

        Raises:
            NotImplementedError: If ``metric`` matches none of the known keys.
        """
        def metric_has(*keys):
            # ``in`` is substring matching when metric is a str.
            return any(key in metric for key in keys)

        # Step 1: pick the "in" and "ood" parts over ALL rows.
        if metric_has('bin_disc'):
            in_part, ood_part = all_logits[:, :-1], all_logits[:, -1:]
        elif metric_has('mc_disc'):
            in_part, ood_part = all_logits[:, :-2], all_logits[:, -2:]
        elif metric_has('msp', 'oe', 'bkg_c', 'energy'):
            in_part = ood_part = all_logits
        elif metric_has('gradnorm'):
            in_part, ood_part = all_logits, all_features
        elif metric_has('maha'):
            in_part = all_logits
            # Insert an axis at position 1 so the score keeps a column dim.
            ood_part = self.calc_maha_score(all_features)[:, None]
        else:
            raise NotImplementedError('parse_logits %s' % metric)

        # Step 2: split rows at num_in (in-samples first, OOD samples after).
        return (
            in_part[:num_in, :],
            ood_part[:num_in, :],
            in_part[num_in:, :],
            ood_part[num_in:, :],
        )