def write_error_samples_multi_label()

in src/common/multi_label_metrics.py [0:0]


def write_error_samples_multi_label(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets,
                                    probs, threshold=0.5,
                                    use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    '''
    Write per-sample multi-label predictions to a CSV file for error analysis.

    One row is written for every sample (not only mis-predicted ones) with the columns
    ``score`` (Jaccard similarity between the true and predicted label sets), ``y_true``,
    ``y_pred`` and ``inputs``. Any existing file at ``filepath`` is overwritten.

    :param filepath: path of the CSV file to (over)write.
    :param samples: sequence of input samples, indexed in parallel with ``targets``; each
        sample is indexed by the positions in ``input_indexs`` and each such position holds
        an iterable of ids.
    :param input_indexs: positions within each sample to decode into names.
    :param input_id_2_names: list parallel to ``input_indexs``; ``input_id_2_names[k][v]``
        maps id ``v`` found at position ``input_indexs[k]`` to a name. If falsy, each sample
        is written as-is.
    :param output_id_2_name: mapping from label index to label name. If falsy, the raw sets
        of label indexes are written.
    :param targets: ground-truth labels, turned into per-sample index lists by
        ``relevant_indexes`` (presumably a binary indicator matrix — confirm with caller).
    :param probs: predicted scores, turned into predictions by ``prob_2_pred``.
    :param threshold: decision threshold forwarded to ``prob_2_pred``.
    :param use_other_diags: if True, positions 6 and 10 are each followed by position
        ``+1``, decoded with the same vocabulary.
    :param use_other_operas: if True, position 8 is followed by position 9, decoded with the
        same vocabulary.
    :param use_checkin_department: if False, position 3 is read from position 12 instead.
    :return: None
    '''
    preds = prob_2_pred(probs, threshold=threshold)
    targets_relevant = relevant_indexes(targets)
    preds_relevant = relevant_indexes(preds)
    # newline='' is what the csv module requires; without it rows get blank lines on Windows.
    with open(filepath, "w", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i, target_indexes in enumerate(targets_relevant):
            target = set(target_indexes)
            pred = set(preds_relevant[i])
            union = target | pred
            # An empty union means no label was expected and none was predicted: a perfect
            # match, so score it 1.0 instead of raising ZeroDivisionError.
            jacc = len(target & pred) / len(union) if union else 1.0
            if output_id_2_name:
                target_labels = [output_id_2_name[v] for v in target]
                pred_labels = [output_id_2_name[v] for v in pred]
            else:
                target_labels = target
                pred_labels = pred
            sample = samples[i]
            if input_id_2_names:
                new_sample = []
                for idx, input_index in enumerate(input_indexs):
                    # Without the check-in department, position 3 is substituted by position 12.
                    if input_index == 3 and not use_checkin_department:
                        input_index = 12
                    id_2_name = input_id_2_names[idx]
                    new_sample.append([id_2_name[v] for v in sample[input_index]])
                    # Some positions have a companion field at input_index + 1 that shares the
                    # same vocabulary; include it when the corresponding flag is enabled.
                    # NOTE(review): position 10 is gated by use_other_diags, not
                    # use_other_operas — confirm this is intended.
                    if ((input_index == 6 and use_other_diags)
                            or (input_index == 8 and use_other_operas)
                            or (input_index == 10 and use_other_diags)):
                        new_sample.append([id_2_name[v] for v in sample[input_index + 1]])
            else:
                new_sample = sample
            writer.writerow([jacc, target_labels, pred_labels, new_sample])