in src/common/multi_label_metrics.py [0:0]
def write_error_samples_multi_label(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets,
                                    probs, threshold=0.5,
                                    use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    '''
    Write per-sample results of a multi-label classifier to a CSV file for bad-case analysis.

    Every sample produces one row with columns ``score`` (Jaccard similarity between the true
    and predicted label sets), ``y_true``, ``y_pred`` (sorted label lists, mapped to names when
    ``output_id_2_name`` is given) and ``inputs`` (the sample, decoded to names when
    ``input_id_2_names`` is given).

    :param filepath: path of the CSV file to write; an existing file is overwritten.
    :param samples: input samples, indexed in parallel with ``targets``/``probs``; each sample is
        indexed by the positions listed in ``input_indexs``.
    :param input_indexs: positions inside each sample of the input fields to write.
    :param input_id_2_names: per-field id -> name lookups, aligned with ``input_indexs``; if falsy,
        each sample is written unchanged.
    :param output_id_2_name: label id -> name lookup; if falsy, raw label ids are written.
    :param targets: ground-truth labels, one row per sample; passed to ``relevant_indexes``
        (presumably multi-hot indicator rows — confirm against callers).
    :param probs: predicted scores, one row per sample; converted to predictions by
        ``prob_2_pred`` using ``threshold``.
    :param threshold: decision threshold forwarded to ``prob_2_pred``.
    :param use_other_diags: if True, also write the field right after input positions 6 and 10.
    :param use_other_operas: if True, also write the field right after input position 8.
    :param use_checkin_department: if False, input position 3 is replaced by position 12.
    :return: None
    '''
    preds = prob_2_pred(probs, threshold=threshold)
    targets_relevant = relevant_indexes(targets)
    preds_relevant = relevant_indexes(preds)

    # newline="" is required by the csv module so that it controls line endings itself;
    # without it an extra blank line appears between rows on Windows.
    with open(filepath, "w", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i, target_ids in enumerate(targets_relevant):
            target = set(target_ids)
            pred = set(preds_relevant[i])

            # Jaccard similarity. If both sets are empty, the prediction matches the (empty)
            # target exactly, so score it 1.0 instead of dividing by zero.
            union = target | pred
            jacc = len(target & pred) / len(union) if union else 1.0

            # Sort ids so the label columns are deterministic and have the same (list) format
            # whether or not a name mapping is supplied.
            if output_id_2_name:
                target_labels = [output_id_2_name[v] for v in sorted(target)]
                pred_labels = [output_id_2_name[v] for v in sorted(pred)]
            else:
                target_labels = sorted(target)
                pred_labels = sorted(pred)

            sample = samples[i]
            if input_id_2_names:
                new_sample = []
                for idx, input_index in enumerate(input_indexs):
                    field_index = input_index
                    # Without the check-in department, position 3 is read from position 12 instead.
                    if field_index == 3 and not use_checkin_department:
                        field_index = 12
                    id_2_name = input_id_2_names[idx]
                    new_sample.append([id_2_name[v] for v in sample[field_index]])
                    # Some fields have a companion field stored at the next position; it is decoded
                    # with the same lookup.
                    # NOTE(review): position 10 is gated by use_other_diags (same as position 6),
                    # not by a dedicated flag — confirm this is intended.
                    if ((field_index == 6 and use_other_diags)
                            or (field_index == 8 and use_other_operas)
                            or (field_index == 10 and use_other_diags)):
                        new_sample.append([id_2_name[v] for v in sample[field_index + 1]])
            else:
                new_sample = sample

            writer.writerow([jacc, target_labels, pred_labels, new_sample])