def write_error_samples_binary()

in src/common/metrics.py [0:0]


def write_error_samples_binary(filepath, samples, input_indexs, input_id_2_names, targets, probs, threshold=0.5,
                               use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    '''
    Write per-sample binary classification results to a CSV file for bad-case review.

    One row is written for every entry of ``targets`` (not only the misclassified
    ones); the ``score`` column is 1 when the thresholded prediction matches the
    ground truth and 0 otherwise, so bad cases are the rows with ``score == 0``.

    :param filepath: path of the CSV file to create (overwritten if it exists).
    :param samples: indexable collection; ``samples[i]`` is the input of the i-th example.
        When ``input_id_2_names`` is truthy, each sample is indexed by feature position
        and every selected feature must be an iterable of ids.
    :param input_indexs: feature positions inside a sample that should be decoded.
    :param input_id_2_names: one id -> name lookup per entry of ``input_indexs`` (matched
        by position). If falsy, the raw sample is written without decoding.
    :param targets: ground-truth labels; only ``targets[i][0]`` is read. Any value other
        than 1 is treated as the negative class.
    :param probs: predicted scores; only ``probs[i][0]`` is read.
    :param threshold: a score ``>= threshold`` is counted as a positive prediction.
    :param use_other_diags: when True, features at positions 6 and 10 are followed by the
        feature at the next position (7 / 11), decoded with the same lookup.
    :param use_other_operas: same as ``use_other_diags`` but for position 8 (followed by 9).
    :param use_checkin_department: when False, position 3 is replaced by position 12.
        NOTE(review): presumably 12 holds a fallback department feature; confirm with the
        code that builds the samples.
    :return: None
    '''
    # The csv module requires newline="" so that it controls line endings itself;
    # otherwise platforms with "\r\n" line endings get an extra blank row per record.
    with open(filepath, "w", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i, target_row in enumerate(targets):
            # Normalise the label to {0, 1}: anything that is not exactly 1 is negative.
            target = 1 if target_row[0] == 1 else 0
            pred = 1 if probs[i][0] >= threshold else 0
            score = 1 if target == pred else 0
            target_label = "True" if target == 1 else "False"
            # The predicted label must come from the prediction, not the ground truth.
            pred_label = "True" if pred == 1 else "False"

            sample = samples[i]
            if input_id_2_names:
                new_sample = []
                for idx, input_index in enumerate(input_indexs):
                    if input_index == 3 and not use_checkin_department:
                        input_index = 12
                    # The lookup is chosen by the position in input_indexs, not by
                    # the (possibly remapped) feature position.
                    id_2_name = input_id_2_names[idx]
                    new_sample.append([id_2_name[v] for v in sample[input_index]])
                    has_companion = (input_index in (6, 10) and use_other_diags) or \
                                    (input_index == 8 and use_other_operas)
                    if has_companion:
                        # The companion feature sits right after the main one and
                        # shares its id -> name lookup.
                        new_sample.append([id_2_name[v] for v in sample[input_index + 1]])
            else:
                new_sample = sample
            writer.writerow([score, target_label, pred_label, new_sample])