def fetch_pos_weights()

in cp_examples/mip_finetune/train_mip.py [0:0]


def fetch_pos_weights(csv, label_list, uncertain_label, nan_label):
    """Compute per-label positive-class weights as ``neg / max(pos, 1)``.

    Each column in ``label_list`` is read from ``csv``. Its entries are
    counted as positive (``== 1``) or negative (``== 0``). Entries equal to
    ``-1`` and missing entries (NaN) are added to the positive or negative
    count depending on ``uncertain_label`` and ``nan_label``. For any other
    value of those arguments they are left out of both counts.

    Args:
        csv: Table of labels, presumably a pandas DataFrame (it is indexed by
            a list of column names and uses ``.isna()``).
        label_list: Column names to compute weights for; output order
            follows this list.
        uncertain_label: ``1`` counts ``-1`` entries as positives, ``-1``
            counts them as negatives, any other value ignores them.
        nan_label: ``1`` counts NaN entries as positives, ``-1`` counts them
            as negatives, any other value ignores them.

    Returns:
        1-D ``torch.float64`` tensor with one weight per entry of
        ``label_list``.
    """
    # Select the label columns once and reuse the sub-table for every count.
    labels = csv[label_list]
    uncertain = (labels == -1).sum()
    missing = labels.isna().sum()

    pos = (labels == 1).sum()
    neg = (labels == 0).sum()

    # NOTE(review): with uncertain_label == -1 the -1 entries are counted as
    # negatives, while uncertain_label == 0 ignores them; confirm this matches
    # how the training loss treats these labels.
    if uncertain_label == 1:
        pos = pos + uncertain
    elif uncertain_label == -1:
        neg = neg + uncertain

    if nan_label == 1:
        pos = pos + missing
    elif nan_label == -1:
        neg = neg + missing

    # Clamp the denominator so a label with no positives yields ``neg``
    # instead of dividing by zero.
    ratio = neg / np.maximum(pos, 1)
    # ``np.float`` was removed in NumPy 1.24. ``np.float64`` is the dtype that
    # alias produced, so the resulting tensor is still torch.float64.
    return torch.tensor(ratio.to_numpy(dtype=np.float64))