in src/common/metrics.py [0:0]
def _decode_error_sample_inputs(sample, input_indexs, input_id_2_names,
                                use_other_diags, use_other_operas, use_checkin_department):
    '''
    Translate the selected feature columns of one sample from ids to names.

    :param sample: one sample; ``sample[k]`` must be an iterable of ids for every column read
    :param input_indexs: feature-column indices to decode, in output order
    :param input_id_2_names: one id->name mapping per entry of ``input_indexs`` (same order)
    :param use_other_diags: also decode column ``k + 1`` after columns 6 and 10
    :param use_other_operas: also decode column ``k + 1`` after column 8
    :param use_checkin_department: if False, column 3 is read from column 12 instead
    :return: list of name lists, one per decoded column
    '''
    decoded = []
    for idx, input_index in enumerate(input_indexs):
        # NOTE(review): the column layout is defined by whoever builds `samples`; presumably
        # 3 is the check-in department and 12 a substitute department column -- TODO confirm.
        if input_index == 3 and not use_checkin_department:
            input_index = 12
        decoded.append([input_id_2_names[idx][v] for v in sample[input_index]])
        # Columns 6, 8 and 10 may carry a companion column at k + 1 (presumably the "other"
        # diagnoses/operations -- TODO confirm). It is decoded with the same mapping as column k.
        has_companion = ((input_index == 6 and use_other_diags)
                         or (input_index == 8 and use_other_operas)
                         or (input_index == 10 and use_other_diags))
        if has_companion:
            decoded.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
    return decoded


def write_error_samples_multi_class(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets, probs,
                                    use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    '''
    Write per-sample predictions of a multi-class classifier to a CSV file for error analysis.

    Every sample is written, not only the misclassified ones. The ``score`` column is 1 when
    the predicted class equals the true class and 0 otherwise, so bad cases can be filtered on it.
    Columns: ``score``, ``y_true``, ``y_pred``, ``inputs`` (``inputs`` is the string form of a list).

    :param filepath: path of the CSV file to create (overwritten if it exists)
    :param samples: samples aligned with ``targets``/``probs`` by position; must have at least
        as many entries as ``targets``
    :param input_indexs: feature-column indices of each sample to decode and write
    :param input_id_2_names: one id->name mapping per entry of ``input_indexs`` (same order);
        if falsy, each raw sample is written unchanged
    :param output_id_2_name: class id -> class name mapping; if falsy, raw class ids are written
    :param targets: true labels, either rows of one-hot/scores with shape (n, num_classes)
        (their argmax is taken) or a 1-D sequence of integer class ids
    :param probs: predicted scores of shape (n, num_classes); the argmax is the predicted class
    :param use_other_diags: also write column ``k + 1`` after feature columns 6 and 10
    :param use_other_operas: also write column ``k + 1`` after feature column 8
    :param use_checkin_department: if False, feature column 3 is read from column 12 instead
    :return: None
    '''
    targets = np.asarray(targets)
    # Accept either one-hot/score rows or plain integer class ids for the ground truth.
    if targets.ndim > 1:
        targets = np.argmax(targets, axis=1)
    preds = np.argmax(probs, axis=1)
    # newline="" is required by the csv module. Without it, platforms whose native line
    # ending is "\r\n" (e.g. Windows) get an extra "\r" per row, i.e. blank lines in the file.
    with open(filepath, "w", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i, target in enumerate(targets):
            pred = preds[i]
            score = int(target == pred)
            if output_id_2_name:
                target_label = output_id_2_name[target]
                pred_label = output_id_2_name[pred]
            else:
                target_label, pred_label = target, pred
            sample = samples[i]
            if input_id_2_names:
                inputs = _decode_error_sample_inputs(sample, input_indexs, input_id_2_names,
                                                     use_other_diags, use_other_operas,
                                                     use_checkin_department)
            else:
                inputs = sample
            writer.writerow([score, target_label, pred_label, inputs])