ood/losses.py (11 lines of code) (raw):

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn.functional as F


class LogitNormLoss(torch.nn.Module):
    """Cross-entropy on L2-normalized, temperature-scaled logits.

    Each logit vector is divided by its L2 norm (taken over the last
    dimension, plus a small ``eps``) and then by the temperature ``t``.
    Cross-entropy against ``target`` is then computed. This matches the
    LogitNorm formulation (Wei et al., ICML 2022).

    Args:
        t: Temperature dividing the unit-norm logits. Smaller values sharpen
            the softmax. It must be non-zero.
        eps: Constant added to the norm so that an all-zero logit vector
            does not cause a division by zero.
        reduction: Forwarded to ``F.cross_entropy``: ``"mean"`` (default),
            ``"sum"`` or ``"none"``.
    """

    def __init__(self, t=.07, *, eps=1e-7, reduction="mean"):
        super().__init__()
        self.t = t
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, target):
        """Return the LogitNorm loss for logits ``x`` and class ``target``.

        NOTE(review): the norm is taken over dim -1, but ``cross_entropy``
        treats dim 1 as the class dimension. The two agree only for
        (N, C) logits, which appears to be the intended shape. Confirm
        before passing higher-rank inputs.
        """
        norms = torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True) + self.eps
        logit_norm = x / norms / self.t
        return F.cross_entropy(logit_norm, target, reduction=self.reduction)