models/base.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from typing import Tuple, Optional

from utils.loss_fn import my_cl_loss_fn3, stable_imbce
from models.feat_pool import IDFeatPool


class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.return_features = False
        self.aux_linear = None
        self.projection = None
        self.lambda_linear = None
        self.id_feat_pool: Optional["IDFeatPool"] = None
        self.num_classes = 1
        self.penultimate_layer_dim = 1

    def build_aux_layers(self):
        self.aux_linear = nn.Linear(1, 1)
        self.projection = nn.Sequential(
            nn.Linear(self.penultimate_layer_dim, self.penultimate_layer_dim),
            nn.ReLU(),
            nn.Linear(self.penultimate_layer_dim, 128)
        )
        self.lambda_linear = nn.Linear(self.penultimate_layer_dim, self.num_classes)
        self.lambda_linear.bias.data.fill_(0.0)

    def forward_features(self, x):
        raise NotImplementedError

    def forward_classifier(self, p4):
        raise NotImplementedError

    def forward_aux_classifier(self, x):
        return self.aux_linear(x)  # (11)

    def forward_lambda(self, x, _prob=None, eps=1e-4) -> torch.Tensor:
        lambd = self.lambda_linear(x).exp()
        # _min, _max = eps/(_prob+eps), (1+eps)/(_prob+eps)
        # lambd = torch.minimum(torch.maximum(lambd, _min), _max)
        return lambd.squeeze()

    def forward_projection(self, p4):
        projected_f = self.projection(p4)  # (10)
        projected_f = F.normalize(projected_f, dim=1)
        return projected_f

    def forward(self, x, mode='forward_only', **kwargs):
        p4 = self.forward_features(x)
        logits = self.forward_classifier(p4)

        ood_p4 = kwargs.pop('ood_data', None)
        if ood_p4 is not None:
            ood_logits = self.forward_classifier(ood_p4)
            logits = torch.cat((logits, ood_logits), dim=0)
            p4 = torch.cat((p4, ood_p4), dim=0)

        return_features = kwargs.pop('return_features', False) or self.return_features
        if mode == 'forward_only':
            return (logits, p4) if return_features else logits
        elif mode == 'calc_loss':
            res = self.calc_loss(logits, p4, **kwargs)
            ret_p4 = p4 if ood_p4 is None else p4[:-len(ood_p4)]
            return (*res, ret_p4) if return_features else res
        else:
            raise NotImplementedError(mode)
    def calc_loss(self, logits, p4, labels, adjustments, args, use_imood=True):
        # ID labels are duplicated: the ID batch holds two views per image (see calc_aux_loss)
        in_labels = torch.cat([labels, labels], dim=0)
        num_sample, total_num_in = logits.shape[0], in_labels.shape[0]
        assert num_sample > total_num_in
        device = in_labels.device
        metric = args.ood_metric

        in_sample_in_logits, in_sample_ood_logits, ood_sample_in_logits, ood_sample_ood_logits \
            = self.parse_logits(logits, p4, metric, total_num_in)

        in_loss, ood_loss, aux_ood_loss = \
            torch.zeros((1,), device=device), torch.zeros((1,), device=device), torch.zeros((1,), device=device)

        if not metric.startswith('ada_'):
            in_loss += F.cross_entropy(in_sample_in_logits + adjustments, in_labels)
            if metric == 'oe':
                ood_loss += -(ood_sample_ood_logits.mean(1) - torch.logsumexp(ood_sample_ood_logits, dim=1)).mean()
            elif metric == 'energy':
                Ec_out = -torch.logsumexp(ood_sample_ood_logits, dim=1)
                Ec_in = -torch.logsumexp(in_sample_ood_logits, dim=1)
                m_in, m_out = -23 if self.num_classes == 10 else -27, -5  # cifar10/100
                # 0.2 * 0.5 = 0.1, the default loss scale in official Energy OOD
                ood_loss += (torch.pow(F.relu(Ec_in - m_in), 2).mean()
                             + torch.pow(F.relu(m_out - Ec_out), 2).mean()) * 0.2
            elif metric == 'bkg_c':
                # every OOD sample is assigned to the extra background class (index num_classes)
                ood_labels = torch.full((num_sample - total_num_in,), self.num_classes,
                                        dtype=torch.long, device=device)
                ood_loss += F.cross_entropy(ood_sample_ood_logits, ood_labels)
            elif metric == 'bin_disc':
                ood_labels = torch.zeros((num_sample,), device=device)
                ood_labels[:total_num_in] = 1.
                ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0).squeeze(1)
                ood_loss += F.binary_cross_entropy_with_logits(ood_logits, ood_labels)
            elif metric == 'mc_disc':
                ood_labels = torch.zeros((num_sample,), dtype=torch.long, device=device)
                ood_labels[:total_num_in] = 1  # id: cls0; ood: cls1
                ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0)
                ood_loss += F.cross_entropy(ood_logits, ood_labels)
            else:
                raise NotImplementedError(metric)
        else:
            ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0)
            ood_logits = self.parse_ada_ood_logits(ood_logits, metric)
            ood_labels = torch.zeros((num_sample,), device=device)
            ood_labels[:total_num_in] = 1.

            cls_prior = F.softmax(adjustments, dim=1)
            min_thresh = 1e-4
            lambd = self.forward_lambda(p4).squeeze().clamp(min=min_thresh)
            smoothing = 0.2
            m_in_labels: torch.Tensor = F.one_hot(in_labels, num_classes=self.num_classes)
            in_posterior = m_in_labels * (1 - smoothing) + smoothing / self.num_classes
            ood_posterior = F.softmax(ood_sample_in_logits.detach(), dim=1)
            cls_posterior = torch.cat((in_posterior, ood_posterior))
            beta = (lambd * cls_posterior / cls_prior).mean(dim=1)  # .clamp(min=1e-1, max=1e+1)
            ood_loss += (beta.log() + ood_logits.detach().sigmoid().log()).relu().mean()

            beta = beta.detach()
            delta = (beta + (beta - 1.) * torch.exp(ood_logits.detach())).clamp(min=1e-1, max=1e+1)
            delta = torch.cat((delta[:total_num_in].clamp(min=1.), delta[total_num_in:].clamp(max=1.)), dim=0)
            ood_logits = ood_logits - delta.log()
            ood_loss += F.binary_cross_entropy_with_logits(ood_logits, ood_labels)

            if metric == 'ada_oe':  # add original OE loss
                ood_loss += -(ood_sample_ood_logits.mean(1) - torch.logsumexp(ood_sample_ood_logits, dim=1)).mean()

            in_sample_in_logits = in_sample_in_logits + adjustments
            in_loss += F.cross_entropy(in_sample_in_logits, in_labels)

        aux_ood_loss += self.calc_aux_loss(p4, labels, args)
        return in_loss, ood_loss, aux_ood_loss
    def calc_aux_loss(self, p4, labels, args):
        device = p4.device
        aux_loss = torch.zeros(1, device=device)
        num_in = labels.shape[0]
        if 'pascl' == args.aux_ood_loss:
            if not hasattr(self, 'cl_loss_weights'):
                _sigmoid_x = torch.linspace(-1, 1, self.num_classes).to(device)
                _d = -2 * args.k + 1 - 0.001  # - 0.001 to make _d < -1 when k=1
                self.register_buffer('cl_loss_weights', torch.sign(_sigmoid_x - _d))
            tail_idx = labels >= round((1 - args.k) * self.num_classes)  # don't use int(): 1-0.9=0.0999... != 0.1
            all_f = self.forward_projection(p4)
            f_id_view1, f_id_view2 = all_f[0:num_in], all_f[num_in:2 * num_in]
            f_id_tail_view1 = f_id_view1[tail_idx]  # i.e., 6,7,8,9 in cifar10
            f_id_tail_view2 = f_id_view2[tail_idx]  # i.e., 6,7,8,9 in cifar10
            labels_tail = labels[tail_idx]
            f_ood = all_f[2 * num_in:]
            if torch.sum(tail_idx) > 0:
                aux_loss += my_cl_loss_fn3(
                    torch.stack((f_id_tail_view1, f_id_tail_view2), dim=1), f_ood, labels_tail,
                    temperature=args.T, reweighting=True, w_list=self.cl_loss_weights
                )
        elif 'simclr' == args.aux_ood_loss:
            ood_logits = self.projection(p4)[:, 0]
            ood_labels = torch.zeros((len(p4),), device=device)
            assert len(p4) > num_in * 2
            ood_labels[:num_in * 2] = 1.
            aux_loss = F.binary_cross_entropy_with_logits(ood_logits, ood_labels)
        elif 'cocl' == args.aux_ood_loss:
            from utils.loss_fn import compute_dist
            Lambda2, Lambda3 = 0.05, 0.1
            temperature, margin = 0.07, 1.0
            headrate, tailrate = 0.4, 0.4
            f_id_view = p4[:2 * num_in]
            f_ood = p4[2 * num_in:]
            head_idx = labels <= round(headrate * self.num_classes)  # don't use int(): 1-0.9=0.0999... != 0.1
            tail_idx = labels >= round((1 - tailrate) * self.num_classes)
            f_id_head_view = f_id_view[head_idx]  # head classes (smallest label indices)
            f_id_tail_view = f_id_view[tail_idx]  # tail classes, i.e., 6,7,8,9 in cifar10
            labels_tail = labels[tail_idx]

            # OOD-aware tail class prototype learning
            if len(f_id_tail_view) > 0 and Lambda2 > 0:
                ## TODO
                raise NotImplementedError
                logits = self.forward_weight(f_id_tail_view, f_ood, temperature=temperature)
                tail_loss = F.cross_entropy(logits, labels_tail - round((1 - tailrate) * self.num_classes))
            else:
                tail_loss = torch.zeros((1,), device=device)

            # debiased head class learning
            if Lambda3 > 0:
                dist1 = compute_dist(f_ood, f_ood)
                _, dist_max1 = torch.max(dist1, 1)
                positive = f_ood[dist_max1]
                dist2 = torch.randint(low=0, high=len(f_id_head_view), size=(1, len(f_ood))).to(device).squeeze()
                negative = f_id_head_view[dist2]
                triplet_loss = torch.nn.TripletMarginLoss(margin=margin)
                head_loss = triplet_loss(f_ood, positive, negative)
            else:
                head_loss = torch.zeros((1,), device=device)

            aux_loss = tail_loss * Lambda2 + head_loss * Lambda3
        return aux_loss
    def parse_logits(self, all_logits, all_features, metric, num_in) \
            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if any(x in metric for x in ['bin_disc']):
            in_sample_in_logits = all_logits[:num_in, :-1]
            in_sample_ood_logits = all_logits[:num_in, -1:]
            ood_sample_in_logits = all_logits[num_in:, :-1]
            ood_sample_ood_logits = all_logits[num_in:, -1:]
        elif any(x in metric for x in ['mc_disc']):
            in_sample_in_logits = all_logits[:num_in, :-2]
            in_sample_ood_logits = all_logits[:num_in, -2:]
            ood_sample_in_logits = all_logits[num_in:, :-2]
            ood_sample_ood_logits = all_logits[num_in:, -2:]
        elif any(x in metric for x in ['msp', 'oe', 'bkg_c', 'energy']):
            in_sample_in_logits = all_logits[:num_in, :]
            in_sample_ood_logits = all_logits[:num_in, :]
            ood_sample_in_logits = all_logits[num_in:, :]
            ood_sample_ood_logits = all_logits[num_in:, :]
        elif any(x in metric for x in ['gradnorm']):
            in_sample_in_logits = all_logits[:num_in, :]
            in_sample_ood_logits = all_features[:num_in, :]
            ood_sample_in_logits = all_logits[num_in:, :]
            ood_sample_ood_logits = all_features[num_in:, :]
        elif any(x in metric for x in ['maha']):
            all_maha_scores = self.calc_maha_score(all_features)
            in_sample_in_logits = all_logits[:num_in, :]
            in_sample_ood_logits = all_maha_scores[:num_in, None]
            ood_sample_in_logits = all_logits[num_in:, :]
            ood_sample_ood_logits = all_maha_scores[num_in:, None]
        else:
            raise NotImplementedError('parse_logits %s' % metric)
        return in_sample_in_logits, in_sample_ood_logits, ood_sample_in_logits, ood_sample_ood_logits

    def parse_ada_ood_logits(self, ood_logits, metric, project=True):
        if any(x in metric for x in ['bin_disc']):
            pass
        else:
            if any(x in metric for x in ['msp', 'oe']):
                ood_logits = F.softmax(ood_logits, dim=1).max(dim=1, keepdim=True).values - 1. / self.num_classes  # MSP
            elif any(x in metric for x in ['energy']):
                ood_logits = torch.logsumexp(ood_logits, dim=1, keepdim=True)
            elif any(x in metric for x in ['maha']):
                pass  # already calculated
            elif any(x in metric for x in ['gradnorm']):
                ood_logits = [self.calc_gradnorm_per_sample(f) for f in ood_logits]
                ood_logits = torch.tensor(ood_logits).view(-1, 1).cuda()
            else:
                raise NotImplementedError(metric)
        if project:
            ood_logits = self.forward_aux_classifier(ood_logits)
        return ood_logits.squeeze(1)

    def calc_maha_score(self, features):
        assert self.id_feat_pool is not None
        if self.training and not self.id_feat_pool.ready():
            return torch.zeros(len(features), device=features.device)
        return self.id_feat_pool.calc_maha_score(features, force_calc=self.training)

    def calc_gradnorm_per_sample(self, features, targets=None, temperature=1.):
        assert len(features.shape) == 1
        self.requires_grad_(True)
        features = features.view(1, -1)
        features = Variable(features.cuda(), requires_grad=True)
        self.zero_grad()
        outputs = self.forward_classifier(features) / temperature
        if targets is None:
            targets = torch.ones((1, self.num_classes)).cuda() / self.num_classes
        kl_loss = F.kl_div(outputs.softmax(dim=-1).log(), targets.softmax(dim=-1), reduction='sum')
        kl_loss.backward()
        layer_grad = self.linear.weight.grad.data
        gradnorm = torch.sum(torch.abs(layer_grad))
        self.requires_grad_(False)
        return gradnorm


def get_ood_scores(model: BaseModel, images, metric, adjustments):
    logits, features = model(images, return_features=True)
    in_logits, ood_logits = model.parse_logits(logits, features, metric, logits.shape[0])[:2]
    if metric.startswith('ada_'):
        ood_logits = model.parse_ada_ood_logits(ood_logits, metric, project=False)
        prior = F.softmax(adjustments, dim=1)
        posterior = F.softmax(in_logits, dim=1)
        out_adjust = (posterior / prior).mean(dim=1).log()
        # 1.0 for bin_disc, 0.1 for msp, 1.0 for energy, 0.02 for pascl, 0.01 for maha
        scale_dict = {'ada_msp': 0.1, 'ada_energy': 1.0, 'ada_bin_disc': 1.0, 'ada_maha': 0.01, 'ada_gradnorm': 10}
        ood_logits += out_adjust * scale_dict[metric]
        scores = - ood_logits
    else:
        prior = F.softmax(adjustments, dim=1)
        posterior = F.softmax(in_logits, dim=1)
        if metric == 'msp':
            # The larger the MSP, the smaller the uncertainty
            scores = - F.softmax(logits, dim=1).max(dim=1).values
        elif metric == 'energy':
            # The larger the energy, the smaller the uncertainty
            tau = 1.
            scores = - tau * torch.logsumexp(logits / tau, dim=1)
        elif metric == 'bkg_c':
            # The larger the softmax background-class prob, the larger the uncertainty
            scores = F.softmax(ood_logits, dim=1)[:, -1]
        elif metric == 'bin_disc':
            # The larger the sigmoid prob, the smaller the uncertainty
            scores = 1. - ood_logits.squeeze(1).sigmoid()
        elif metric == 'mc_disc':
            # The larger the softmax prob, the smaller the uncertainty
            scores = F.softmax(ood_logits, dim=1)[:, 1]
        elif metric == 'rp_msp':
            # The larger the MSP, the smaller the uncertainty
            scores = - (F.softmax(logits, dim=1) - .01 * F.softmax(adjustments, dim=1)).max(dim=1).values
        elif metric == 'rp_gradnorm':
            # The larger the GradNorm, the smaller the uncertainty
            prior = F.softmax(adjustments, dim=1)
            scores = [model.calc_gradnorm_per_sample(feat, targets=prior) for feat in features]
            scores = - torch.tensor(scores)
        elif metric == 'gradnorm':
            # The larger the GradNorm, the smaller the uncertainty
            scores = [model.calc_gradnorm_per_sample(feat) for feat in features]
            scores = - torch.tensor(scores)
        elif metric == 'rw_energy':
            # The larger the energy, the smaller the uncertainty
            tau = 1.
            prior = F.softmax(adjustments, dim=1)
            posterior = F.softmax(logits, dim=1)
            rweight = 1. - (posterior * prior).sum(dim=1) / (posterior.norm(2, dim=1) * prior.norm(2, dim=1))
            scores = - tau * torch.logsumexp(logits / tau, dim=1) * rweight
        elif metric == 'maha':
            scores = - ood_logits[:, 0]  # already calculated
        else:
            raise NotImplementedError('OOD inference metric: %s' % metric)
    return in_logits, scores
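

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# example of how a subclass could wire `forward_features` / `forward_classifier`
# into `calc_loss` and `get_ood_scores`. `ToyModel`, the layer sizes, the batch
# shapes, and the fields of `args` below are assumptions for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    class ToyModel(BaseModel):
        # hypothetical subclass: a one-layer feature extractor and a linear classifier
        def __init__(self, in_dim=32, num_classes=10):
            super().__init__()
            self.num_classes = num_classes
            self.penultimate_layer_dim = 64
            self.backbone = nn.Linear(in_dim, self.penultimate_layer_dim)
            # +1 output logit serves as the binary OOD head used by the 'bin_disc' metric
            self.linear = nn.Linear(self.penultimate_layer_dim, num_classes + 1)
            self.build_aux_layers()

        def forward_features(self, x):
            return F.relu(self.backbone(x))

        def forward_classifier(self, p4):
            return self.linear(p4)

    model = ToyModel()
    # assumed training hyper-parameters; only `ood_metric` and `aux_ood_loss` are used here
    args = argparse.Namespace(ood_metric='bin_disc', aux_ood_loss='none', T=0.07, k=0.4)
    adjustments = torch.zeros(1, model.num_classes)         # logit adjustments (uniform prior here)
    x = torch.randn(8, 32)                                   # two views of 4 ID images, concatenated
    labels = torch.randint(0, model.num_classes, (4,))
    ood_feats = model.forward_features(torch.randn(6, 32))   # features of an auxiliary OOD batch

    # training-style call: returns (in_loss, ood_loss, aux_ood_loss)
    losses = model(x, mode='calc_loss', ood_data=ood_feats,
                   labels=labels, adjustments=adjustments, args=args)
    print([float(l) for l in losses])

    # inference-style call: per-sample OOD scores (larger = more uncertain)
    model.eval()
    with torch.no_grad():
        _, scores = get_ood_scores(model, torch.randn(4, 32), 'bin_disc', adjustments)
    print(scores.shape)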