models/base.py
def calc_loss(self, logits, p4, labels, adjustments, args, use_imood=True):
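    """Compute the ID classification loss and the OOD regularization loss.

    Shape assumptions, inferred from usage rather than stated in this file:
    `logits` stacks two augmented views of the ID batch followed by the OOD
    batch, `p4` is the backbone feature fed to the auxiliary/lambda heads,
    and `adjustments` holds per-class logit adjustments (log class priors).
    """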
    in_labels = torch.cat([labels, labels], dim=0)  # labels duplicated to match the two stacked ID views
    num_sample, total_num_in = logits.shape[0], in_labels.shape[0]
    assert num_sample > total_num_in  # the batch must also contain OOD samples
    device = in_labels.device
    metric = args.ood_metric
    in_sample_in_logits, in_sample_ood_logits, ood_sample_in_logits, ood_sample_ood_logits \
        = self.parse_logits(logits, p4, metric, total_num_in)
    in_loss, ood_loss, aux_ood_loss = \
        torch.zeros((1,), device=device), torch.zeros((1,), device=device), torch.zeros((1,), device=device)
    if not metric.startswith('ada_'):
        # logit-adjusted cross-entropy on the ID classification head
        in_loss += F.cross_entropy(in_sample_in_logits + adjustments, in_labels)
        if metric == 'oe':
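            # Outlier Exposure (Hendrycks et al.): mean(1) - logsumexp is the mean
            # log-softmax over classes, so this term is exactly the cross-entropy
            # between a uniform target and the OOD-sample predictions.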
            ood_loss += -(ood_sample_ood_logits.mean(1) - torch.logsumexp(ood_sample_ood_logits, dim=1)).mean()
        elif metric == 'energy':
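            # Energy-based fine-tuning (Liu et al., NeurIPS 2020): a squared hinge
            # pushes ID energies below the margin m_in and OOD energies above m_out,
            # where energy is the negative logsumexp of the logits.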
            Ec_out = -torch.logsumexp(ood_sample_ood_logits, dim=1)
            Ec_in = -torch.logsumexp(in_sample_ood_logits, dim=1)
            m_in, m_out = -23 if self.num_classes == 10 else -27, -5  # cifar10 / cifar100 margins
            # 0.2 * 0.5 = 0.1, the default loss scale in the official Energy OOD code
            ood_loss += (torch.pow(F.relu(Ec_in - m_in), 2).mean()
                         + torch.pow(F.relu(m_out - Ec_out), 2).mean()) * 0.2
        elif metric == 'bkg_c':
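            # Background-class variant: every OOD sample is labeled with the extra
            # class index num_classes, so this head must expose num_classes + 1 logits.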
            # one target per OOD sample (dtype long, as F.cross_entropy requires)
            ood_labels = torch.full((num_sample - total_num_in,), self.num_classes,
                                    device=device, dtype=torch.long)
            ood_loss += F.cross_entropy(ood_sample_ood_logits, ood_labels)
        elif metric == 'bin_disc':
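            # Binary ID-vs-OOD discriminator on a single-logit head:
            # ID views are the positives (label 1), OOD samples the negatives (label 0).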
            ood_labels = torch.zeros((num_sample,), device=device)
            ood_labels[:total_num_in] = 1.  # id: 1; ood: 0
            ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0).squeeze(1)
            ood_loss += F.binary_cross_entropy_with_logits(ood_logits, ood_labels)
        elif metric == 'mc_disc':
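            # Two-way softmax counterpart of 'bin_disc' above.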
            # F.cross_entropy expects integer class indices, not floats
            ood_labels = torch.zeros((num_sample,), device=device, dtype=torch.long)
            ood_labels[:total_num_in] = 1  # id: cls 1; ood: cls 0
            ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0)
            ood_loss += F.cross_entropy(ood_logits, ood_labels)
        else:
            raise NotImplementedError(metric)
    else:
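        # Adaptive ('ada_*') variants. Reading inferred from the code: a per-sample
        # ratio beta ~ lambda * class-posterior / class-prior rescales the implied
        # ID probability, and delta turns that rescaling into a shift of the binary
        # ID/OOD logit before the BCE loss.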
        ood_logits = torch.cat((in_sample_ood_logits, ood_sample_ood_logits), dim=0)
        ood_logits = self.parse_ada_ood_logits(ood_logits, metric)
        ood_labels = torch.zeros((num_sample,), device=device)
        ood_labels[:total_num_in] = 1.
        cls_prior = F.softmax(adjustments, dim=1)  # class prior recovered from the logit adjustments
        min_thresh = 1e-4
        # per-sample weight; keep a trailing singleton dim so it broadcasts over the
        # class dimension below (a bare `.squeeze()` would break that broadcast)
        lambd = self.forward_lambda(p4).reshape(-1, 1).clamp(min=min_thresh)
        smoothing = 0.2
        # label-smoothed one-hot posterior for the ID samples
        m_in_labels = F.one_hot(in_labels, num_classes=self.num_classes).float()
        in_posterior = m_in_labels * (1 - smoothing) + smoothing / self.num_classes
        ood_posterior = F.softmax(ood_sample_in_logits.detach(), dim=1)
        cls_posterior = torch.cat((in_posterior, ood_posterior), dim=0)
        beta = (lambd * cls_posterior / cls_prior).mean(dim=1)  # .clamp(min=1e-1, max=1e+1)
        ood_loss += (beta.log() + ood_logits.detach().sigmoid().log()).relu().mean()
        beta = beta.detach()
        delta = (beta + (beta - 1.) * torch.exp(ood_logits.detach())).clamp(min=1e-1, max=1e+1)
        # clamp so ID logits are only ever shifted down (delta >= 1) and
        # OOD logits only ever shifted up (delta <= 1)
        delta = torch.cat((delta[:total_num_in].clamp(min=1.),
                           delta[total_num_in:].clamp(max=1.)), dim=0)
        ood_logits = ood_logits - delta.log()
        ood_loss += F.binary_cross_entropy_with_logits(ood_logits, ood_labels)
        if metric == 'ada_oe':  # also add the original OE loss term
            ood_loss += -(ood_sample_ood_logits.mean(1) - torch.logsumexp(ood_sample_ood_logits, dim=1)).mean()
        in_sample_in_logits = in_sample_in_logits + adjustments
        in_loss += F.cross_entropy(in_sample_in_logits, in_labels)
    aux_ood_loss += self.calc_aux_loss(p4, labels, args)
    return in_loss, ood_loss, aux_ood_loss
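
# Minimal usage sketch (illustrative only: `model`, `x_id_view1/2`, `x_ood`, and
# `forward_features` are hypothetical stand-ins, not names defined in this file):
#
#   x = torch.cat([x_id_view1, x_id_view2, x_ood], dim=0)  # ID views first, OOD last
#   logits, p4 = model.forward_features(x)
#   in_loss, ood_loss, aux_ood_loss = model.calc_loss(
#       logits, p4, labels, adjustments, args)
#   (in_loss + ood_loss + aux_ood_loss).backward()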