def metrics_multi_label()

in src/common/multi_label_metrics.py [0:0]


def metrics_multi_label(targets,  probs, threshold=0.5):
    '''
    Compute sample-averaged metrics for multi-label classification.

    Each row of ``targets`` / ``probs`` is one sample. A label is predicted
    positive when its probability is >= ``threshold``. For every sample, the
    set of relevant target labels is compared with the set of predicted labels
    to get accuracy, precision, recall, Jaccard similarity and F1. ROC-AUC and
    PR-AUC are computed per sample from the raw probabilities. Per-sample
    values are averaged over samples and rounded to 4 decimals.

    Per-sample conventions:
      * no relevant target labels AND no predicted labels -> every metric
        (including both AUCs) is recorded as 1.0;
      * acc = 1 - |target set symmetric-difference pred set| / n_labels;
      * precision is 0.0 when nothing is predicted;
      * recall is 1.0 when there are no relevant target labels;
      * AUCs are recorded only when the sample's target row contains more than
        one distinct value; otherwise that sample does not contribute to them.

    :param targets: true 0-1 indicator matrix (n_samples, n_labels)
    :param probs: probs 0~1 probability matrix (n_samples, n_labels)
    :param threshold: negative-positive threshold; probs >= threshold count as positive
    :return: dict with keys "acc", "jaccard", "prec", "recall", "f1", "pr_auc",
        "roc_auc" (sample averages rounded to 4 decimals, 0 if no sample
        contributed) and "fmax", "pmax", "rmax", "tmax" (first four values
        returned by ``f_max(targets, probs)``)
    '''
    # presumably relevant_indexes returns, per row, the column indices of the
    # positive (1) entries -- verify against its definition
    targets_relevant = relevant_indexes(targets)
    preds_relevant = relevant_indexes((probs >= threshold).astype(int))
    acc_list = []
    prec_list = []
    recall_list = []
    jaccard_list = []
    f1_list = []
    roc_auc_list = []
    pr_auc_list = []
    for idx in range(targets.shape[0]):
        target_relevant = targets_relevant[idx]
        pred_relevant = preds_relevant[idx]
        target_len = len(target_relevant)
        predict_len = len(pred_relevant)
        # Build each set once and reuse it for both union and intersection.
        target_set = set(target_relevant)
        pred_set = set(pred_relevant)
        union_len = len(target_set | pred_set)
        intersection_len = len(target_set & pred_set)
        if union_len == 0:
            # Nothing relevant and nothing predicted: treated as a perfect sample.
            # NOTE(review): both AUCs are recorded as 1.0 here, whereas in the
            # branch below a row with single-class targets contributes no AUC at
            # all. A row without relevant target labels therefore counts toward
            # the AUC averages only when no label crosses `threshold`, which makes
            # the (threshold-free) AUC averages threshold-dependent. Kept as-is to
            # preserve reported numbers -- confirm intent.
            acc_list.append(1.0)
            prec_list.append(1.0)
            recall_list.append(1.0)
            roc_auc_list.append(1.0)
            jaccard_list.append(1.0)
            f1_list.append(1.0)
            pr_auc_list.append(1.0)
            continue

        # acc: 1 - (size of symmetric difference) / n_labels
        acc = 1.0 - (union_len - intersection_len) / targets.shape[1]
        acc_list.append(acc)

        # precision (0.0 when nothing is predicted)
        prec = intersection_len / predict_len if predict_len > 0 else 0.0
        prec_list.append(prec)

        # recall (1.0 when the sample has no relevant target labels)
        recall = intersection_len / target_len if target_len > 0 else 1.0
        recall_list.append(recall)

        # jaccard similarity; union_len > 0 is guaranteed in this branch
        jaccard_list.append(intersection_len / union_len)

        # f1 (harmonic mean of prec and recall, 0.0 when both are 0)
        if prec + recall == 0:
            f1 = 0.0
        else:
            f1 = 2.0 * prec * recall / (prec + recall)
        f1_list.append(f1)

        # roc_auc / pr_auc are only defined when both classes appear in the
        # row's targets; single-class rows are skipped rather than scored.
        if len(np.unique(targets[idx, :])) > 1:
            roc_auc_list.append(roc_auc_macro(targets[idx, :], probs[idx, :]))
            pr_auc_list.append(pr_auc_macro(targets[idx, :], probs[idx, :]))

    # The fifth value returned by f_max is not reported.
    f_max_value, p_max_value, r_max_value, t_max_value, _ = f_max(targets, probs)
    return {
        "acc": _rounded_mean(acc_list),
        "jaccard": _rounded_mean(jaccard_list),
        "prec": _rounded_mean(prec_list),
        "recall": _rounded_mean(recall_list),
        "f1": _rounded_mean(f1_list),
        "pr_auc": _rounded_mean(pr_auc_list),
        "roc_auc": _rounded_mean(roc_auc_list),
        "fmax": f_max_value,
        "pmax": p_max_value,
        "rmax": r_max_value,
        "tmax": t_max_value
    }


def _rounded_mean(values, ndigits=4):
    '''Return the mean of ``values`` rounded to ``ndigits`` decimals, or 0 if ``values`` is empty.'''
    return round(sum(values) / len(values), ndigits) if values else 0