def write_error_samples_multi_class()

in src/common/metrics.py [0:0]


def write_error_samples_multi_class(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets, probs,
                                    use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    """
    Write one CSV row per sample for reviewing multi-class classification results.

    The file gets a header row ``score, y_true, y_pred, inputs`` followed by one
    row per entry of ``targets``. ``score`` is 1 when the predicted class equals
    the true class and 0 otherwise, so correct predictions are written too.

    :param filepath: path of the CSV file to create (an existing file is overwritten).
    :param samples: indexable collection of samples; ``samples[i]`` is itself
        indexed by the input positions listed in ``input_indexs``.
    :param input_indexs: positions inside each sample to export, in output order.
    :param input_id_2_names: sequence of id -> name lookups, aligned with
        ``input_indexs`` by position. When falsy, the raw sample is written instead.
    :param output_id_2_name: class id -> label lookup. When falsy, the numeric
        class ids are written.
    :param targets: 2-D scores per class; the true class is the argmax along axis 1
        (presumably one-hot labels -- confirm with the caller).
    :param probs: 2-D predicted scores per class; the predicted class is the
        argmax along axis 1.
    :param use_other_diags: when true, positions 6 and 10 are each followed by the
        entry at the next position (7 / 11), decoded with the same lookup.
    :param use_other_operas: when true, position 8 is followed by the entry at
        position 9, decoded with the same lookup.
    :param use_checkin_department: when false, position 3 is replaced by position 12.
    :return: None
    """
    targets = np.argmax(targets, axis=1)
    preds = np.argmax(probs, axis=1)
    # newline="" is required by the csv module: the writer emits its own "\r\n"
    # terminators, and without it text-mode translation on Windows turns them
    # into "\r\r\n", which shows up as a blank row after every record.
    with open(filepath, "w", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i, target in enumerate(targets):
            pred = preds[i]
            score = 1 if target == pred else 0
            if output_id_2_name:
                target_label = output_id_2_name[target]
                pred_label = output_id_2_name[pred]
            else:
                target_label = target
                pred_label = pred
            sample = samples[i]
            if input_id_2_names:
                new_sample = []
                for idx, input_index in enumerate(input_indexs):
                    # NOTE(review): the fixed positions below (3, 12, 6, 8, 10)
                    # mirror a sample layout defined elsewhere -- confirm they
                    # still match the feature extraction code.
                    if input_index == 3 and not use_checkin_department:
                        input_index = 12
                    id_2_name = input_id_2_names[idx]
                    new_sample.append([id_2_name[v] for v in sample[input_index]])
                    # The "other" values live in the slot right after the main
                    # one and share its lookup table.
                    if (input_index in (6, 10) and use_other_diags) or (input_index == 8 and use_other_operas):
                        new_sample.append([id_2_name[v] for v in sample[input_index + 1]])
            else:
                new_sample = sample
            writer.writerow([score, target_label, pred_label, new_sample])