in models/base.py [0:0]
def parse_logits(self, all_logits, all_features, metric, num_in) \
        -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a batch into the four logit groups a given metric works on.

    Rows ``[:num_in]`` are returned as the "in sample" groups and rows
    ``[num_in:]`` as the "ood sample" groups (presumably the batch is the
    in-distribution samples followed by the OOD samples -- confirm with
    the caller).  Which columns / which tensor feed the "in" and "ood"
    groups depends on ``metric``:

    * ``bin_disc``: all but the last logit column / the last column.
    * ``mc_disc``: all but the last two columns / the last two columns.
    * ``msp``, ``oe``, ``bkg_c``, ``energy``: every logit column for both.
    * ``gradnorm``: every logit column / every column of ``all_features``.
    * ``maha``: every logit column / ``self.calc_maha_score(all_features)``
      with a new axis inserted at position 1.

    A tag matches when ``tag in metric`` is true (a substring test when
    ``metric`` is a ``str``).  Tags are tried in the order listed above
    and the first match wins.

    Args:
        all_logits: Logits for the whole batch, indexed ``[row, column]``.
        all_features: Features for the whole batch; only used by the
            ``gradnorm`` and ``maha`` branches.
        metric: Name of the scoring metric.
        num_in: Row index at which the batch is split.

    Returns:
        ``(in_sample_in_logits, in_sample_ood_logits,
        ood_sample_in_logits, ood_sample_ood_logits)``.

    Raises:
        NotImplementedError: If ``metric`` matches none of the known tags.
    """

    def mentions(*tags):
        # Same membership test as the explicit any(...) over a tag list.
        return any(tag in metric for tag in tags)

    every_col = slice(None)

    # The "in" groups always come from all_logits; each branch only picks
    # where the "ood" groups come from and which columns both groups keep.
    if mentions('bin_disc'):
        ood_source = all_logits
        in_cols, ood_cols = slice(None, -1), slice(-1, None)
    elif mentions('mc_disc'):
        ood_source = all_logits
        in_cols, ood_cols = slice(None, -2), slice(-2, None)
    elif mentions('msp', 'oe', 'bkg_c', 'energy'):
        ood_source = all_logits
        in_cols = ood_cols = every_col
    elif mentions('gradnorm'):
        ood_source = all_features
        in_cols = ood_cols = every_col
    elif mentions('maha'):
        ood_source = self.calc_maha_score(all_features)
        # Indexing with None adds a new axis at position 1 to the scores.
        in_cols, ood_cols = every_col, None
    else:
        raise NotImplementedError('parse_logits %s' % metric)

    head_rows = slice(None, num_in)
    tail_rows = slice(num_in, None)
    return (
        all_logits[head_rows, in_cols],
        ood_source[head_rows, ood_cols],
        all_logits[tail_rows, in_cols],
        ood_source[tail_rows, ood_cols],
    )