in ood/metrics.py [0:0]
# Every method name accepted by get_msp_scores.
_SUPPORTED_METHODS = (
    'MCM', 'Energy', 'MaxScore',
    'MCM_Full', 'MCM_Full_Hard', 'MCM_Full_Soft', 'MCM_Full_Neg', 'MCM_Full_Scale',
    'MCM_Pair_Hard', 'MCM_Pair_Soft', 'MCM_Pair_Scale',
)


def _pair_scale(pair_logits):
    """Return the softmax weight of column 0 of a 2-column pair, floored at 0.5, shape (nb, 1)."""
    return F.softmax(pair_logits, dim=1)[:, :1].clamp(min=0.5)


def _full_scores(logits, ood_logits, method, pred, rows):
    """Score the ``MCM_Full*`` methods via a softmax over the joint ID + OOD columns.

    Returns ``(scores, inconsistent)``; ``inconsistent`` is True where the argmax over
    the concatenated logits differs from the ID-only prediction ``pred``.
    """
    full_logits = torch.cat((logits, ood_logits), dim=1)
    full_probs = F.softmax(full_logits, dim=1)
    full_pred = full_logits.argmax(dim=1)
    inconsistent = pred != full_pred

    if method == 'MCM_Full_Neg':
        # Probability mass on the OOD columns; higher score means OOD.
        num_id = logits.shape[1]
        return full_probs[:, num_id:].sum(dim=1), inconsistent

    if method == 'MCM_Full_Scale':
        # Per-sample scale = softmax weight of the max ID logit vs. the max OOD logit,
        # floored at 0.5; it multiplies the joint logits before the softmax.
        max_pair = torch.stack((logits.max(dim=1).values, ood_logits.max(dim=1).values), dim=1)
        scaled_probs = F.softmax(full_logits * _pair_scale(max_pair), dim=1)
        return -scaled_probs[rows, pred], inconsistent

    scores = -full_probs[rows, pred]  # higher score means OOD
    if method == 'MCM_Full_Hard':
        # Inconsistent samples get the maximum possible score (0).
        scores[inconsistent] = 0.
    elif method == 'MCM_Full_Soft':
        # Adding the non-negative gap between the joint max prob and the prob of the
        # ID prediction raises the score; the gap is zero for consistent samples.
        scores = scores + (full_probs[rows, full_pred] - full_probs[rows, pred])
    return scores, inconsistent


def _pair_scores(logits, ood_logits, method, pred, rows):
    """Score the ``MCM_Pair*`` methods by pairing each predicted ID logit with the OOD logit in the same column.

    Returns ``(scores, inconsistent)``; ``inconsistent`` is True where the OOD logit of
    the predicted class is larger than its ID logit.
    """
    pair_logits = torch.stack((logits[rows, pred], ood_logits[rows, pred]), dim=1)  # (nb, 2)
    inconsistent = pair_logits[:, 0] < pair_logits[:, 1]

    if method == 'MCM_Pair_Scale':
        probs = F.softmax(logits * _pair_scale(pair_logits), dim=1)
        return -probs[rows, pred], inconsistent

    msp = F.softmax(logits, dim=1).max(dim=1).values
    scores = -msp  # higher score means OOD
    if method == 'MCM_Pair_Hard':
        scores[inconsistent] = 0.
    elif method == 'MCM_Pair_Soft':
        # Multiplying a negative score by a factor in [0.5, 1] raises it.
        scores = scores * F.softmax(pair_logits, dim=1)[:, 0].clamp(min=0.5)
    else:
        raise NotImplementedError
    return scores, inconsistent


def _max_scores(logits, ood_logits, pred, rows):
    """Score ``MaxScore``: the predicted ID logit weighted by a sharpened ID-vs-OOD pair softmax."""
    id_logit = logits[rows, pred]
    pair_logits = torch.stack((id_logit, ood_logits[rows, pred]), dim=1)  # (nb, 2)
    pair_probs = F.softmax(pair_logits * 10., dim=1)
    return -pair_probs[:, 0] * id_logit


def get_msp_scores(logits, ood_logits=None, method='MCM', ret_near_ood=False):
    """Compute per-sample OOD scores; a higher score means more likely OOD.

    Args:
        logits: ID logits; rows are samples and ``dim=1`` indexes ID classes
            (assumed 2-D ``(nb, C)`` — TODO confirm with callers).
        ood_logits: OOD logits, required by every method except ``'MCM'`` and
            ``'Energy'`` (ignored by those two). ``MCM_Full*`` concatenates them
            after ``logits`` along ``dim=1``. ``MCM_Pair*`` and ``'MaxScore'`` index
            them with the ID prediction, so their columns must align with the ID
            classes.
        method: one of ``_SUPPORTED_METHODS``.
        ret_near_ood: if True, also return the boolean ``inconsistent`` mask
            (``None`` for ``'MCM'``, ``'Energy'`` and ``'MaxScore'``).

    Returns:
        ``scores``, or ``(scores, inconsistent)`` when ``ret_near_ood`` is True.

    Raises:
        ValueError: if ``method`` is unknown, or ``ood_logits`` is None for a
            method that needs it.
    """
    # Explicit raises rather than ``assert``: asserts are stripped under ``python -O``,
    # which previously let an unknown method fall through to an UnboundLocalError.
    if method not in _SUPPORTED_METHODS:
        raise ValueError('OOD inference method %s has not been implemented.' % method)

    inconsistent = None
    if method == 'Energy':
        tau = 100.  # logits are multiplied by tau before logsumexp
        scores = -torch.logsumexp(logits * tau, dim=1)
    elif method == 'MCM':
        scores = -F.softmax(logits, dim=1).max(dim=1).values  # higher score means OOD
    else:
        if ood_logits is None:
            raise ValueError('OOD inference method %s requires ood_logits.' % method)
        pred = logits.argmax(dim=1)
        rows = torch.arange(logits.shape[0], device=logits.device)
        if method.startswith('MCM_Full'):
            scores, inconsistent = _full_scores(logits, ood_logits, method, pred, rows)
        elif method.startswith('MCM_Pair'):
            scores, inconsistent = _pair_scores(logits, ood_logits, method, pred, rows)
        else:  # 'MaxScore'
            scores = _max_scores(logits, ood_logits, pred, rows)

    if ret_near_ood:
        return scores, inconsistent
    return scores