def fetch_pos_weights()

in cp_examples/sip_finetune/train_sip.py [0:0]


def fetch_pos_weights(dataset_name, csv, label_list, uncertain_label, nan_label):
    """Compute per-label positive-class weights as ``neg / max(pos, 1)``.

    Args:
        dataset_name: When equal to ``"nih"``, labels are read from the
            ``"Finding Labels"`` text column of ``csv`` and a row counts as
            positive for a label when ``str.contains(label)`` matches.
            For any other value, ``csv[label_list]`` is read as one column
            per label, with ``1`` counted as positive, ``0`` as negative and
            ``-1`` as uncertain.
        csv: Table holding the labels (indexed like a pandas DataFrame).
        label_list: Label names, in the order the weights are returned.
        uncertain_label: Non-"nih" only. ``1`` adds the ``-1`` (uncertain)
            entries to the positive count, ``-1`` adds them to the negative
            count; any other value leaves them uncounted.
        nan_label: Non-"nih" only. ``1`` adds missing (NaN) entries to the
            positive count, ``-1`` adds them to the negative count; any
            other value leaves them uncounted.

    Returns:
        A 1-D ``torch.float64`` tensor with one weight per entry of
        ``label_list``.
    """
    if dataset_name == "nih":
        # Evaluate each label's match mask once and reuse it for both counts.
        masks = [csv["Finding Labels"].str.contains(lab) for lab in label_list]
        pos = [mask.sum() for mask in masks]
        neg = [(~mask).sum() for mask in masks]
        ratio = neg / np.maximum(pos, 1)
    else:
        labels = csv[label_list]
        pos = (labels == 1).sum()
        neg = (labels == 0).sum()

        if uncertain_label == 1:
            pos = pos + (labels == -1).sum()
        elif uncertain_label == -1:
            neg = neg + (labels == -1).sum()

        if nan_label == 1:
            pos = pos + labels.isna().sum()
        elif nan_label == -1:
            neg = neg + labels.isna().sum()

        ratio = (neg / np.maximum(pos, 1)).values

    # The divisor is clamped to 1 above so labels without positives do not
    # divide by zero. `np.float` was removed in NumPy 1.24; `np.float64` is
    # the dtype that alias resolved to, so the tensor dtype is unchanged.
    return torch.tensor(np.asarray(ratio, dtype=np.float64))