ood/posthoc.py (36 lines of code) (raw):
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
dataset = 'im1000'
thresh_dict = {
'im1000': {0.05: [0.5640, -0.5430], 0.10: [0.4189, -0.4080]},
'im100': {0.05: [0.5625, -0.5415], 0.10: [0.4180, -0.4075]},
}
def applyReAct(feature, p=0.10):
pos_thresh, neg_thresh = thresh_dict[dataset][p]
feature[feature < neg_thresh] = neg_thresh
# feature[feature > pos_thresh] = pos_thresh
return feature
try:
feat_mean = torch.load(f'restore/{dataset}-feat_mean.pt').cuda()
feat_std = torch.load(f'restore/{dataset}-feat_std.pt').cuda()
except:
feat_mean, feat_std = None, None
print('Warning: feat_mean and feat_std not found for BATS.')
def applyBATS(feature, lambd=2):
feature = torch.where(feature<(feat_std*lambd+feat_mean),feature,feat_std*lambd+feat_mean)
feature = torch.where(feature>(-feat_std*lambd+feat_mean),feature,-feat_std*lambd+feat_mean)
return feature
def applyASH(feature, prune_ratio=0.05, method='S'): # 0.05, S
s = feature.abs().sum(dim=1, keepdim=True)
n = feature.shape[1]
k = int(round(n * prune_ratio))
v, i = torch.topk(feature, k, dim=1, largest=False) # .abs() for imagenet-100
feature.scatter_(dim=1, index=i, src=v.detach()*0)
if method == 'B':
raise NotImplementedError('ASH-B leads to poor performance')
feature = (feature != 0).type_as(v) * s / (n - k)
elif method == 'S':
s2 = feature.abs().sum(dim=1, keepdim=True)
feature *= s / s2
return feature
if __name__ == '__main__':
pass