def plot_pr_curve_multiclass()

in distributed_training/util/inference_utils.py [0:0]


def plot_pr_curve_multiclass(y_true_ohe, y_score, num_classes, color_table, skip_legend=5, is_single_fig=False):
    """
    Plot per-class, micro-average and macro-average precision-recall curves.

    A new 8x8 figure is created containing one dotted curve per class plus
    two thick curves for the micro- and macro-averages. Nothing is returned;
    the figure stays the current pyplot figure.

    Args:
        y_true_ohe: 2-D array of ground-truth indicators, indexed as
            ``y_true_ohe[:, i]`` for class ``i`` (presumably one-hot
            encoded labels -- confirm against the caller).
        y_score: 2-D array of per-class scores, indexed the same way as
            ``y_true_ohe``.
        num_classes: Number of class columns to evaluate and plot.
        color_table: Iterable of matplotlib colors; it is cycled when there
            are more classes than colors.
        skip_legend: Only every ``skip_legend``-th class (0, skip_legend,
            2 * skip_legend, ...) gets a legend entry, to keep the legend
            readable. Must be non-zero.
        is_single_fig: When true, ``plt.show()`` is called after drawing.
    """
    precision = {}
    recall = {}
    average_precision = {}
    for i in range(num_classes):
        precision[i], recall[i], _ = precision_recall_curve(y_true_ohe[:, i], y_score[:, i])
        average_precision[i] = average_precision_score(y_true_ohe[:, i], y_score[:, i])

    # A "micro-average": quantifying score on all classes jointly
    precision["micro"], recall["micro"], _ = precision_recall_curve(y_true_ohe.ravel(), y_score.ravel())
    average_precision["micro"] = average_precision_score(y_true_ohe, y_score, average="micro")
    average_precision["macro"] = average_precision_score(y_true_ohe, y_score, average="macro")

    # Macro-average curve: average the per-class precision over a shared
    # recall grid. Recall is the monotonic axis of a PR curve, so it is the
    # only valid x-coordinate for np.interp (which requires increasing xp);
    # precision is not monotonic and must not be used as the x-axis.
    all_recall = np.unique(np.concatenate([recall[i] for i in range(num_classes)]))

    mean_precision = np.zeros_like(all_recall)
    for i in range(num_classes):
        # precision_recall_curve returns recall in decreasing order, so both
        # arrays are reversed to hand np.interp increasing x-coordinates.
        mean_precision += np.interp(all_recall, recall[i][::-1], precision[i][::-1])

    # Finally average over the classes
    mean_precision /= num_classes

    precision["macro"] = mean_precision
    recall["macro"] = all_recall

    colors = cycle(color_table)
    fig, ax = plt.subplots(figsize=(8, 8))
    label = 'micro-average Precision-recall (area = {0:0.4f})'.format(average_precision["micro"])
    ax.plot(recall["micro"], precision["micro"], label=label, color='deeppink', lw=3)

    label = 'macro-average Precision-recall (area = {0:0.4f})'.format(average_precision["macro"])
    ax.plot(recall["macro"], precision["macro"], label=label, color='navy', lw=3)

    for i, color in zip(range(num_classes), colors):
        # Label only a subset of the classes so the legend stays legible.
        if i % skip_legend == 0:
            label = 'PR for class {0} (area = {1:0.4f})'.format(i, average_precision[i])
        else:
            label = None
        ax.plot(recall[i], precision[i], color=color, label=label, lw=2, alpha=0.5, linestyle=':')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall curve to multi-class')
    ax.legend(loc="lower left", prop={'size': 10})

    if is_single_fig:
        plt.show()