in src/common/metrics.py [0:0]
def write_error_samples_binary(filepath, samples, input_indexs, input_id_2_names, targets, probs, threshold=0.5,
                               use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    """
    Write per-sample results of a binary classifier to a CSV file.

    One row is written for every sample (not only the misclassified ones);
    the ``score`` column is 1 when the prediction matches the target and 0
    otherwise, so bad cases can be filtered on ``score == 0``.

    :param filepath: path of the CSV file to create (overwritten if present)
    :param samples: indexable collection of samples; ``samples[i]`` is itself
        indexable by the column indices in ``input_indexs``
    :param input_indexs: column indices of each sample to write out
    :param input_id_2_names: per-input id -> name lookups, indexed by the
        position in ``input_indexs``; when falsy the raw sample is written
    :param targets: ``targets[i][0]`` is the true label; exactly 1 means
        positive, anything else is treated as negative
    :param probs: ``probs[i][0]`` is the predicted positive-class score
    :param threshold: a score >= threshold is predicted positive
    :param use_other_diags: also write column ``index + 1`` for inputs 6 and 10
    :param use_other_operas: also write column ``index + 1`` for input 8
    :param use_checkin_department: when False, input 3 is read from column 12
    :return: None
    """
    # newline="" is required by the csv module to avoid blank lines on
    # platforms that translate line endings.
    with open(filepath, "w", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i, target_row in enumerate(targets):
            # Normalize the label to {0, 1}. The previous code mapped every
            # non-1 label to 1, which made y_true always "True".
            target = 1 if target_row[0] == 1 else 0
            pred = 1 if probs[i][0] >= threshold else 0
            score = 1 if target == pred else 0

            target_label = "True" if target == 1 else "False"
            # Derived from the prediction (previously copied from the target).
            pred_label = "True" if pred == 1 else "False"

            sample = samples[i]
            if input_id_2_names:
                new_sample = []
                for idx, input_index in enumerate(input_indexs):
                    id_2_name = input_id_2_names[idx]
                    column = input_index
                    # Without the check-in department, slot 3 is read from column 12.
                    if column == 3 and not use_checkin_department:
                        column = 12
                    new_sample.append([id_2_name[v] for v in sample[column]])
                    # NOTE(review): input 10 is gated by use_other_diags, same
                    # as input 6 - kept as in the original; confirm intended.
                    has_extra_column = (
                        (column == 6 and use_other_diags)
                        or (column == 8 and use_other_operas)
                        or (column == 10 and use_other_diags)
                    )
                    if has_extra_column:
                        new_sample.append([id_2_name[v] for v in sample[column + 1]])
            else:
                new_sample = sample

            writer.writerow([score, target_label, pred_label, new_sample])