in models/base.py [0:0]
def parse_ada_ood_logits(self, ood_logits, metric, project=True):
    """Turn raw OOD logits into a per-sample OOD score selected by ``metric``.

    ``metric`` is matched by substring, checked in this order:

    * ``'bin_disc'``       -- logits are used unchanged.
    * ``'msp'`` / ``'oe'`` -- maximum softmax probability over dim 1, shifted
      down by ``1 / self.num_classes``.
    * ``'energy'``         -- ``logsumexp`` over dim 1 (dim kept).
    * ``'maha'``           -- used unchanged (already calculated upstream).
    * ``'gradnorm'``       -- ``self.calc_gradnorm_per_sample`` applied to each
      element of ``ood_logits``; the results are packed into a ``(-1, 1)``
      column on the input's device.

    Args:
        ood_logits: Logits to score; presumably a ``(batch, num_classes)``
            tensor -- TODO confirm against callers.
        metric: Name of the scoring rule (substring-matched, see above).
        project: If true, pass the score through
            ``self.forward_aux_classifier`` before returning.

    Returns:
        The score tensor with dimension 1 squeezed (a no-op unless that
        dimension has size 1).

    Raises:
        NotImplementedError: If ``metric`` matches none of the known rules.
    """
    def _matches(*keys):
        # Substring match: any metric name containing one of ``keys`` qualifies.
        return any(key in metric for key in keys)

    if _matches('bin_disc'):
        pass  # logits are consumed as-is
    elif _matches('msp', 'oe'):
        # Max softmax probability, centred by 1/K (its value for a uniform prediction).
        probs = F.softmax(ood_logits, dim=1)
        ood_logits = probs.max(dim=1, keepdim=True).values - 1. / self.num_classes  # MSP
    elif _matches('energy'):
        ood_logits = torch.logsumexp(ood_logits, dim=1, keepdim=True)
    elif _matches('maha'):
        pass  # already calculated
    elif _matches('gradnorm'):
        # Place the per-sample column on the input's device instead of a
        # hard-coded .cuda(), which crashes on CPU-only hosts and can disagree
        # with the device the rest of the computation runs on.
        device = getattr(ood_logits, 'device', None)
        if device is None:
            # Non-tensor input (e.g. a list): keep the old GPU placement when possible.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        scores = [self.calc_gradnorm_per_sample(f) for f in ood_logits]
        ood_logits = torch.tensor(scores).view(-1, 1).to(device)
    else:
        raise NotImplementedError(metric)
    if project:
        ood_logits = self.forward_aux_classifier(ood_logits)
    return ood_logits.squeeze(1)