def plot_precision_recall_curve()

in builtin_algorithm_hpo_tabular/util/classification_report.py [0:0]


def plot_precision_recall_curve(y_real,
                                y_predict,
                                axis=None,
                                plot_style='ggplot'):
    """Plot a precision-recall curve annotated with its average precision.

    Parameters
    ----------
    y_real : array-like
        Ground-truth binary labels.
    y_predict : array-like
        Scores for the positive class (e.g. predicted probabilities), as
        expected by ``metrics.precision_recall_curve``. Hard 0/1 predictions
        produce a degenerate curve with very few points.
    axis : matplotlib Axes, optional
        Axes to draw on, e.g. one panel of a subplot grid. When ``None``, a
        new figure is created, laid out with ``tight_layout`` and shown.
    plot_style : str, optional
        Matplotlib style sheet name passed to ``plt.style.use``. This
        changes the global matplotlib style, not just this plot.

    Returns
    -------
    matplotlib Axes
        The axes the curve was drawn on.
    """
    # Apply the style *before* creating a figure. Properties such as the axes
    # face colour and grid are read from rcParams when an Axes is created, so
    # setting the style afterwards left standalone plots unstyled.
    plt.style.use(plot_style)

    standalone = axis is None
    if standalone:
        _, ax = plt.subplots()
    else:
        ax = axis

    precision, recall, _ = metrics.precision_recall_curve(y_real, y_predict)
    avg_precision = metrics.average_precision_score(y_real, y_predict)

    ax.set_aspect(aspect=0.95)
    # Step plot with 'post' so each precision value is held until the next
    # recall threshold, matching how average precision is computed.
    ax.step(recall, precision, color='b', where='post', linewidth=0.7)
    ax.fill_between(recall, precision, step='post', alpha=0.2, color='b')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_ylim([0.0, 1.05])
    ax.set_xlim([0.0, 1.05])
    ax.set_title('Precision-Recall curve: AP={0:0.3f}'.format(avg_precision))

    if standalone:
        plt.tight_layout()
        plt.show()

    return ax