in cp_examples/sip_finetune/train_sip.py [0:0]
def fetch_pos_weights(dataset_name, csv, label_list, uncertain_label, nan_label):
    """Compute a per-label negative-to-positive count ratio as a tensor.

    For each label the result is ``neg / max(pos, 1)``; the denominator is
    clamped to 1 so a label with no positive samples does not divide by zero.
    (Presumably consumed as a ``pos_weight`` for a weighted BCE loss --
    confirm against the caller.)

    Args:
        dataset_name: Dataset identifier. ``"nih"`` selects the branch that
            reads labels from the ``"Finding Labels"`` string column; any
            other value selects the one-column-per-label branch.
        csv: pandas DataFrame holding the label information.
        label_list: Label names. For ``"nih"`` each one is searched for in
            ``"Finding Labels"`` via ``Series.str.contains`` (pandas' default
            regex matching); otherwise each one is a column of ``csv``.
        uncertain_label: Non-NIH only. If ``1``, entries equal to ``-1`` are
            added to the positive count; if ``-1``, they are added to the
            negative count; any other value leaves them uncounted.
        nan_label: Non-NIH only. Same rule as ``uncertain_label`` but applied
            to missing (NaN) entries.

    Returns:
        A 1-D ``torch.float64`` tensor with one entry per label, in
        ``label_list`` order.
    """
    if dataset_name == "nih":
        findings = csv["Finding Labels"]
        # Build each label's membership mask once and reuse it for both counts.
        masks = [findings.str.contains(lab) for lab in label_list]
        pos = np.asarray([mask.sum() for mask in masks])
        neg = np.asarray([(~mask).sum() for mask in masks])
        ratio = neg / np.maximum(pos, 1)
    else:
        labels = csv[label_list]
        pos = (labels == 1).sum()
        neg = (labels == 0).sum()

        if uncertain_label == 1:
            pos = pos + (labels == -1).sum()
        elif uncertain_label == -1:
            neg = neg + (labels == -1).sum()

        if nan_label == 1:
            pos = pos + labels.isna().sum()
        elif nan_label == -1:
            neg = neg + labels.isna().sum()

        ratio = (neg / np.maximum(pos, 1)).values

    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype that
    # alias resolved to, so the tensor dtype is unchanged.
    return torch.tensor(ratio.astype(np.float64))