# ood/metrics.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
import numpy as np


def get_msp_scores(logits, ood_logits=None, method='MCM', ret_near_ood=False):
    """Compute per-sample OOD scores; a higher score always means more likely OOD.

    `logits`: (batch, num_id_classes) ID class similarities. `ood_logits`:
    (batch, num_neg_labels) negative-label similarities, required by every
    method except 'MCM' and 'Energy'. With `ret_near_ood=True`, also return the
    mask of samples whose ID prediction is beaten by a negative label.
    """
    assert method in ['MCM', 'Energy', 'MaxScore',
                      'MCM_Full', 'MCM_Full_Hard', 'MCM_Full_Soft', 'MCM_Full_Neg',
                      'MCM_Pair_Hard', 'MCM_Pair_Soft', 'MCM_Pair_Scale', 'MCM_Full_Scale'], \
        'OOD inference method %s has not been implemented.' % method
    probs = F.softmax(logits, dim=1)
    msp = probs.max(dim=1).values  # maximum softmax probability (MSP)
    inconsistent = None  # mask marking samples where a negative label beats the ID prediction
    if method == 'Energy':
        # Energy-style score: sharpen the similarity logits by tau before the
        # logsumexp, negated so that higher values indicate OOD.
        tau = 100.
        scores = -torch.logsumexp(logits * tau, dim=1)
    elif method == 'MCM':
        scores = - msp  # negative MSP: higher score means OOD
    else:
        assert ood_logits is not None, 'method %s requires ood_logits' % method
        pred = logits.argmax(dim=1)             # predicted ID class per sample
        xrange = torch.arange(logits.shape[0])  # batch indices for row-wise indexing
        if 'MCM_Full' in method:
            # Softmax over the concatenated ID and negative labels.
            full_logits = torch.cat((logits, ood_logits), dim=1)
            full_probs = F.softmax(full_logits, dim=1)
            full_pred = full_logits.argmax(dim=1)
            inconsistent = pred != full_pred  # some negative label outranks every ID class
            if 'Neg' in method:
                # Total probability mass assigned to the negative labels.
                cls_num = logits.shape[1]
                scores = full_probs[:, cls_num:].sum(dim=1)  # higher score means OOD
else:
scores = - full_probs[xrange, pred]
                if 'Hard' in method:
                    scores[inconsistent] = 0.  # force the maximal (most OOD) score
                elif 'Soft' in method:
                    # adding a positive delta to a negative score raises it (more OOD)
                    scores += full_probs[xrange, full_pred] - full_probs[xrange, pred]
                elif 'Scale' in method:
                    # Soften the softmax temperature when the best negative label
                    # rivals the best ID class (scale in [0.5, 1]).
                    max_id_sim, max_ood_sim = logits.max(dim=1)[0], ood_logits.max(dim=1)[0]
                    pair_logits = torch.stack((max_id_sim, max_ood_sim), dim=1)
                    scale = F.softmax(pair_logits, dim=1)[:, :1].clamp(min=0.5)
                    full_probs = F.softmax(full_logits * scale, dim=1)
                    scores = - full_probs[xrange, pred]
else:
assert method == 'MCM_Full'
elif 'MCM_Pair' in method:
scores = - msp # higher score means OOD
pair_logits = torch.stack((logits[xrange, pred], ood_logits[xrange, pred]), dim=1) # shape(nb,2)
inconsistent = pair_logits[:, 0] < pair_logits[:, 1] # id_sim < ood_sim
pair_probs = F.softmax(pair_logits, dim=1)
            if 'Hard' in method:
                scores[inconsistent] = 0.  # force the maximal (most OOD) score
            elif 'Soft' in method:
                # shrinking a negative score toward zero raises it (more OOD)
                scores *= pair_probs[:, 0].clamp(min=0.5)  # higher score means OOD
            elif 'Scale' in method:
                # Temperature-scale the ID logits by how confidently the ID
                # similarity beats its paired negative one (scale in [0.5, 1]).
                scale = F.softmax(pair_logits, dim=1)[:, :1].clamp(min=0.5)
                probs = F.softmax(logits * scale, dim=1)
                scores = - probs[xrange, pred]
else:
raise NotImplementedError
        elif method == 'MaxScore':
            # Weight the predicted-class similarity by how sharply it beats its
            # paired negative similarity (temperature 10), negated so that
            # higher scores indicate OOD.
            pair_logits = torch.stack((logits[xrange, pred], ood_logits[xrange, pred]), dim=1)  # shape (nb, 2)
            pair_probs = F.softmax(pair_logits * 10., dim=1)
            scores = -pair_probs[:, 0] * logits[xrange, pred]
if ret_near_ood:
return scores, inconsistent
else:
return scores
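

# A minimal usage sketch (hypothetical batch size, class counts, and method
# choice; not part of the original pipeline): score a batch against 100 ID
# classes and 1000 negative labels.
def _demo_get_msp_scores():
    id_logits = torch.randn(8, 100)     # similarities to ID class prompts
    neg_logits = torch.randn(8, 1000)   # similarities to negative-label prompts
    scores, near_ood = get_msp_scores(id_logits, neg_logits,
                                      method='MCM_Full_Soft', ret_near_ood=True)
    return scores, near_ood             # higher score means more likely OOD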


def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
"""Use high precision for cumsum and check that final value matches sum
Parameters
----------
arr : array-like
To be cumulatively summed as flat
rtol : float
Relative tolerance, see ``np.allclose``
atol : float
Absolute tolerance, see ``np.allclose``
"""
out = np.cumsum(arr, dtype=np.float64)
expected = np.sum(arr, dtype=np.float64)
if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
raise RuntimeError('cumsum was found to be unstable: '
'its last element does not correspond to sum')
return out
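

# Quick sanity sketch (hypothetical values): the float64 cumsum's last entry
# should agree with np.sum within the given tolerances.
def _demo_stable_cumsum():
    arr = np.array([0.1] * 10, dtype=np.float32)
    return stable_cumsum(arr)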


def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
    """Return (FPR, threshold) at the point on the curve whose recall (TPR) on
    the positive class is closest to `recall_level`, e.g. FPR@95TPR. Despite
    the name, the FDR term is left commented out and not returned.
    """
classes = np.unique(y_true)
if (pos_label is None and
not (np.array_equal(classes, [0, 1]) or
np.array_equal(classes, [-1, 1]) or
np.array_equal(classes, [0]) or
np.array_equal(classes, [-1]) or
np.array_equal(classes, [1]))):
raise ValueError("Data is not binary and pos_label is not specified")
elif pos_label is None:
pos_label = 1.
# make y_true a boolean vector
y_true = (y_true == pos_label)
# sort scores and corresponding truth values
desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
y_score = y_score[desc_score_indices]
y_true = y_true[desc_score_indices]
# y_score typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = np.where(np.diff(y_score))[0]
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
# accumulate the true positives with decreasing threshold
tps = stable_cumsum(y_true)[threshold_idxs]
fps = 1 + threshold_idxs - tps # add one because of zero-based indexing
thresholds = y_score[threshold_idxs]
recall = tps / tps[-1]
last_ind = tps.searchsorted(tps[-1])
sl = slice(last_ind, None, -1) # [last_ind::-1]
recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]
cutoff = np.argmin(np.abs(recall - recall_level))
return fps[cutoff] / (np.sum(np.logical_not(y_true))), thresholds[cutoff] # , fps[cutoff]/(fps[cutoff] + tps[cutoff])
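

# Usage sketch on synthetic scores (assumed: label 1 marks positives and higher
# scores rank them first), reporting FPR at 95% TPR and the threshold used.
def _demo_fpr_at_recall():
    rng = np.random.default_rng(0)
    y_true = np.concatenate([np.ones(100), np.zeros(100)])
    y_score = np.concatenate([rng.normal(1., 1., 100), rng.normal(0., 1., 100)])
    return fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95)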


def get_measures(_pos, _neg, recall_level=0.95):
    """Compute (AUROC, AUPR, FPR@recall_level, threshold), treating `_pos` as the
    scores of the positive class (label 1) and `_neg` as the negatives (label 0)."""
    pos = np.array(_pos[:]).reshape((-1, 1))
    neg = np.array(_neg[:]).reshape((-1, 1))
    examples = np.squeeze(np.vstack((pos, neg)))  # positives first, then negatives
    labels = np.zeros(len(examples), dtype=np.int32)
    labels[:len(pos)] += 1
auroc = roc_auc_score(labels, examples)
aupr = average_precision_score(labels, examples)
fpr, thresh = fpr_and_fdr_at_recall(labels, examples, recall_level)
return auroc, aupr, fpr, thresh
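

# End-to-end sketch with synthetic data (assumed convention: the OOD split is
# passed as `_pos`, since higher scores from get_msp_scores indicate OOD).
def _demo_get_measures():
    rng = np.random.default_rng(0)
    ood_scores = rng.normal(1., 1., 500)  # stand-in for scores on OOD inputs
    id_scores = rng.normal(0., 1., 500)   # stand-in for scores on ID inputs
    auroc, aupr, fpr95, thresh = get_measures(ood_scores, id_scores)
    return auroc, aupr, fpr95, thresh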