def build_prior()

in utils/common.py [0:0]


def build_prior(args, model, img_num_per_cls, num_classes, num_outputs, device):
    """Build the additive log-prior (logit adjustment) row vector.

    When ``args.logit_adjust > 0`` each entry is
    ``args.logit_adjust * log(n_c / sum(n) + 1e-12)`` for the class count
    ``n_c``; otherwise the vector is all zeros (no adjustment).  When
    ``args.ood_metric == 'bkg_c'`` the vector is right-padded with zeros up to
    ``num_outputs`` columns so it lines up with the extra output columns
    (presumably background/OOD logits -- confirm against the model head).

    Args:
        args: Namespace-like config; reads ``logit_adjust`` (numeric scale)
            and ``ood_metric`` (str).
        model: Unused; kept for call-site compatibility.
        img_num_per_cls: Per-class sample counts, as a NumPy array or any
            sequence accepted by ``torch.as_tensor``.
        num_classes: Number of classes; width of the zero vector when no
            adjustment is applied and ``ood_metric != 'bkg_c'``.
        num_outputs: Total model output width; target width for ``'bkg_c'``.
        device: Device for the returned tensor.

    Returns:
        A tensor of shape ``(1, K)``.  ``K`` is ``num_outputs`` for ``'bkg_c'``
        (when the counts are not already wider than that).  Otherwise ``K`` is
        ``len(img_num_per_cls)`` if an adjustment is applied, or
        ``num_classes`` if not.
    """
    use_bkg = args.ood_metric == 'bkg_c'
    # as_tensor accepts NumPy arrays (sharing memory, like from_numpy) and
    # also plain lists/tuples of counts.
    counts = torch.as_tensor(img_num_per_cls).to(device)
    if args.logit_adjust > 0:
        prior = counts / counts.sum()
        # 1e-12 keeps the log finite for classes with zero samples.
        adjustments = args.logit_adjust * torch.log(prior + 1e-12)[None, :]
        pad = num_outputs - adjustments.shape[1]
        if use_bkg and pad > 0:
            # Pad by the real shortfall.  Slicing an existing row to build the
            # placeholder (the previous approach) under-pads whenever there
            # are more extra outputs than classes.
            adjustments = torch.cat(
                (adjustments, adjustments.new_zeros((1, pad))), dim=1
            )
    else:
        width = num_outputs if use_bkg else num_classes
        adjustments = torch.zeros((1, width), device=device)

    return adjustments