in models/base.py [0:0]
def calc_aux_loss(self, p4, labels, args):
    """Compute the auxiliary OOD loss selected by ``args.aux_ood_loss``.

    Every branch slices its features along dim 0 as
    ``[ID view 1 | ID view 2 | OOD]``, where each in-distribution (ID) view
    has ``num_in = labels.shape[0]`` rows and whatever follows is OOD data.

    Supported modes:

    * ``'pascl'``  - contrastive loss (``my_cl_loss_fn3``) between the two
      projected views of tail-class samples and the projected OOD samples.
    * ``'simclr'`` - binary cross-entropy on the first output column of
      ``self.projection``: target 1 for the ID rows, 0 for the OOD rows.
    * ``'cocl'``   - ``0.05 * tail_loss + 0.1 * head_loss``. The tail term
      is not implemented yet; the head term is a triplet margin loss.

    Args:
        p4: Feature tensor; dim 0 is expected to hold at least
            ``2 * num_in`` rows laid out as described above.
        labels: Class labels of the ID samples, one per sample and shared
            by both views (assumed 1-D - TODO confirm against the caller).
        args: Namespace providing ``aux_ood_loss`` and, for ``'pascl'``,
            ``k`` (tail fraction) and ``T`` (temperature).

    Returns:
        The loss tensor. ``'pascl'`` and unrecognised modes yield a tensor
        of shape ``(1,)`` (all zeros when nothing contributes); ``'simclr'``
        yields whatever ``F.binary_cross_entropy_with_logits`` returns;
        ``'cocl'`` yields the weighted sum of its two terms.

    Raises:
        NotImplementedError: In ``'cocl'`` mode when the batch contains at
            least one tail-class sample (tail loss is still a TODO).
    """
    device = p4.device
    aux_loss = torch.zeros(1, device=device)
    num_in = labels.shape[0]
    mode = args.aux_ood_loss

    if mode == 'pascl':
        if not hasattr(self, 'cl_loss_weights'):
            # Built lazily on first use: one weight per class, obtained as
            # the sign of (class position in [-1, 1]) - _d.
            _sigmoid_x = torch.linspace(-1, 1, self.num_classes).to(device)
            _d = -2 * args.k + 1 - 0.001  # - 0.001 to make _d < -1 when k=1
            self.register_buffer('cl_loss_weights', torch.sign(_sigmoid_x - _d))

        # round(), not int(): float error (1 - 0.9 = 0.0999...) would
        # otherwise truncate the threshold to the wrong class index.
        tail_idx = labels >= round((1 - args.k) * self.num_classes)

        all_f = self.forward_projection(p4)
        f_id_view1, f_id_view2 = all_f[0:num_in], all_f[num_in:2 * num_in]
        f_ood = all_f[2 * num_in:]

        if tail_idx.any():
            # Tail-class rows only, e.g. classes 6, 7, 8, 9 of 10.
            tail_pairs = torch.stack(
                (f_id_view1[tail_idx], f_id_view2[tail_idx]), dim=1
            )
            aux_loss = aux_loss + my_cl_loss_fn3(
                tail_pairs, f_ood, labels[tail_idx], temperature=args.T,
                reweighting=True, w_list=self.cl_loss_weights
            )

    elif mode == 'simclr':
        ood_logits = self.projection(p4)[:, 0]
        ood_labels = torch.zeros((len(p4),), device=device)
        # There must be at least one OOD row after the two ID views.
        assert len(p4) > num_in * 2
        ood_labels[:num_in * 2] = 1.
        aux_loss = F.binary_cross_entropy_with_logits(ood_logits, ood_labels)

    elif mode == 'cocl':
        Lambda2, Lambda3 = 0.05, 0.1  # weights of the tail / head terms
        margin = 1.0
        headrate, tailrate = 0.4, 0.4

        f_id_view = p4[:2 * num_in]
        f_ood = p4[2 * num_in:]

        # round(), not int(): see the note in the 'pascl' branch.
        # NOTE(review): `<=` makes the head set include the class index
        # round(headrate * num_classes) itself (classes 0..4 of 10), one
        # class more than the tail set (6..9) - confirm this is intended.
        head_idx = labels <= round(headrate * self.num_classes)
        tail_idx = labels >= round((1 - tailrate) * self.num_classes)

        # f_id_view stacks both views (2 * num_in rows) while the masks are
        # built from `labels` (num_in entries), so the masks are repeated
        # once per view before indexing.
        f_id_head_view = f_id_view[head_idx.repeat(2)]
        f_id_tail_view = f_id_view[tail_idx.repeat(2)]

        # OOD-aware tail class prototype learning
        if len(f_id_tail_view) > 0 and Lambda2 > 0:
            # TODO: not implemented. The intended computation is
            # self.forward_weight(f_id_tail_view, f_ood, temperature=0.07)
            # followed by a cross-entropy against the tail labels shifted
            # by round((1 - tailrate) * self.num_classes).
            raise NotImplementedError
        else:
            tail_loss = torch.zeros((1, ), device=device)

        # debiased head class learning
        # Needs at least one head sample to draw negatives from and at
        # least one OOD anchor; otherwise the term contributes nothing.
        if Lambda3 > 0 and len(f_id_head_view) > 0 and len(f_ood) > 0:
            # Imported at its only point of use.
            from utils.loss_fn import compute_dist

            # Positive for each OOD anchor: the OOD row with the largest
            # compute_dist value along dim 1 (presumably a pairwise
            # distance matrix - verify against utils.loss_fn).
            dist1 = compute_dist(f_ood, f_ood)
            _, dist_max1 = torch.max(dist1, 1)
            positive = f_ood[dist_max1]

            # Negative for each OOD anchor: a uniformly random head sample.
            # Sampled as a 1-D index so a single OOD row keeps its batch dim.
            neg_idx = torch.randint(
                low=0, high=len(f_id_head_view), size=(len(f_ood),),
                device=device
            )
            negative = f_id_head_view[neg_idx]

            triplet_loss = torch.nn.TripletMarginLoss(margin=margin)
            head_loss = triplet_loss(f_ood, positive, negative)
        else:
            head_loss = torch.zeros((1, ), device=device)

        aux_loss = tail_loss * Lambda2 + head_loss * Lambda3

    return aux_loss