in distributed_training/util/inference_utils.py [0:0]
def plot_pr_curve_multiclass(y_true_ohe, y_score, num_classes, color_table, skip_legend=5, is_single_fig=False):
    """
    Plot precision-recall curves for a multi-class problem.

    Draws one dotted curve per class, plus a micro-average curve (every
    (sample, class) pair pooled into a single binary problem) and a
    macro-average curve (per-class precision averaged on a common recall grid).

    Parameters
    ----------
    y_true_ohe : 2-D array, one column per class
        Binary indicator matrix of the true labels. Column ``i`` is the
        binary target for class ``i``.
    y_score : 2-D array, same layout as ``y_true_ohe``
        Per-class scores. Column ``i`` is scored against column ``i`` of
        ``y_true_ohe``.
    num_classes : int
        Number of leading columns (classes) to compute per-class curves for.
    color_table : iterable of matplotlib colors
        Colors for the per-class curves. It is cycled if shorter than
        ``num_classes``.
    skip_legend : int, default 5
        Only classes whose index is a multiple of ``skip_legend`` get a legend
        entry, which keeps the legend readable. Must be non-zero.
    is_single_fig : bool, default False
        If True, call ``plt.show()`` after drawing.

    Raises
    ------
    ZeroDivisionError
        If ``skip_legend`` is 0.
    """
    precision = dict()
    recall = dict()
    average_precision = dict()
    for i in range(num_classes):
        precision[i], recall[i], _ = precision_recall_curve(y_true_ohe[:, i], y_score[:, i])
        average_precision[i] = average_precision_score(y_true_ohe[:, i], y_score[:, i])

    # Micro-average: quantify the score on all classes jointly by flattening
    # every (sample, class) pair into one binary problem.
    precision["micro"], recall["micro"], _ = precision_recall_curve(y_true_ohe.ravel(), y_score.ravel())
    average_precision["micro"] = average_precision_score(y_true_ohe, y_score, average="micro")
    # Macro AP: unweighted mean of the per-label APs over all columns.
    average_precision["macro"] = average_precision_score(y_true_ohe, y_score, average="macro")

    # Macro-average curve: interpolate each class's precision onto the union
    # of all recall values, then average across classes.
    # Recall must be the interpolation axis. precision_recall_curve returns
    # recall in monotonically decreasing order, whereas precision is not
    # monotonic at all. np.interp needs increasing sample points, hence the
    # [::-1] reversal of each curve.
    all_recall = np.unique(np.concatenate([recall[i] for i in range(num_classes)]))
    mean_precision = np.zeros_like(all_recall)
    for i in range(num_classes):
        mean_precision += np.interp(all_recall, recall[i][::-1], precision[i][::-1])
    mean_precision /= num_classes
    precision["macro"] = mean_precision
    recall["macro"] = all_recall

    colors = cycle(color_table)
    _, ax = plt.subplots(figsize=(8, 8))
    label = 'micro-average Precision-recall (area = {0:0.4f})'.format(average_precision["micro"])
    ax.plot(recall["micro"], precision["micro"], label=label, color='deeppink', lw=3)
    # The "area" shown for the macro curve is the mean per-class AP, not the
    # area under the interpolated curve drawn here.
    label = 'macro-average Precision-recall (area = {0:0.4f})'.format(average_precision["macro"])
    ax.plot(recall["macro"], precision["macro"], label=label, color='navy', lw=3)
    for i, color in zip(range(num_classes), colors):
        # Label only every skip_legend-th class so the legend stays readable.
        if i % skip_legend == 0:
            label = 'PR for class {0} (area = {1:0.4f})'.format(i, average_precision[i])
        else:
            label = None
        ax.plot(recall[i], precision[i], color=color, label=label, lw=2, alpha=0.5, linestyle=':')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall curve to multi-class')
    ax.legend(loc="lower left", prop={'size': 10})
    if is_single_fig:
        plt.show()