in utils/common.py [0:0]
def build_prior(args, model, img_num_per_cls, num_classes, num_outputs, device):
    """Build the additive logit-adjustment row ``tau * log(pi)`` for a classifier.

    ``pi`` is the normalised class-frequency vector derived from
    ``img_num_per_cls`` and ``tau`` is ``args.logit_adjust``. When
    ``args.logit_adjust <= 0`` no adjustment is applied and an all-zero row is
    returned instead.

    Args:
        args: Config namespace; reads ``args.logit_adjust`` (temperature) and
            ``args.ood_metric`` (``'bkg_c'`` widens the row to ``num_outputs``).
        model: Unused; kept so existing call sites keep working.
        img_num_per_cls: Per-class counts, presumably 1-D (numpy array, list or
            tensor). Not read when ``args.logit_adjust <= 0``.
        num_classes: Width of the returned row when ``ood_metric != 'bkg_c'``
            and no adjustment is applied.
        num_outputs: Total classifier outputs; width of the returned row when
            ``ood_metric == 'bkg_c'``.
        device: Torch device for the result.

    Returns:
        A tensor of shape ``(1, width)``. For ``'bkg_c'``, the columns beyond
        the counted classes are zero, so those extra outputs are not adjusted.

    Raises:
        ValueError: If ``ood_metric == 'bkg_c'`` and there are more counted
            classes than ``num_outputs``.
    """
    use_full_width = args.ood_metric in ['bkg_c']

    # No logit adjustment: a zero row of the appropriate width is a no-op prior.
    if args.logit_adjust <= 0:
        width = num_outputs if use_full_width else num_classes
        return torch.zeros((1, width), device=device)

    # as_tensor accepts numpy arrays (sharing memory on CPU, like from_numpy),
    # lists and existing tensors.
    counts = torch.as_tensor(img_num_per_cls, device=device)
    prior = counts / counts.sum()
    # 1e-12 guards against log(0) for classes with no samples.
    adjustments = args.logit_adjust * torch.log(prior + 1e-12)[None, :]

    if use_full_width and adjustments.shape[1] != num_outputs:
        # Pad by the exact number of missing columns. The old slice-based
        # padding could never exceed the current width and was wrong whenever
        # num_outputs > 2 * num_classes.
        pad = num_outputs - adjustments.shape[1]
        if pad < 0:
            raise ValueError(
                f"prior has {adjustments.shape[1]} columns but num_outputs is {num_outputs}"
            )
        adjustments = torch.cat((adjustments, adjustments.new_zeros((1, pad))), dim=1)
    return adjustments