def pr_curve_summary()

in python/mxboard/summary.py [0:0]


def pr_curve_summary(tag, labels, predictions, num_thresholds, weights=None):
    """Builds a `Summary` protocol buffer carrying precision-recall curve data.

    The inputs are converted to numpy arrays, the curve data is computed by
    `_compute_curve`, and the result is wrapped in a float tensor tagged with
    `pr_curves` plugin metadata.

    Parameters
    ----------
        tag : str
            A tag attached to the summary. Used by TensorBoard for organization.
        labels : MXNet `NDArray` or `numpy.ndarray`.
            The ground truth values. A tensor of 0/1 values with arbitrary shape.
        predictions : MXNet `NDArray` or `numpy.ndarray`.
            A float32 tensor whose values are in the range `[0, 1]`. Dimensions must
            match those of `labels`.
        num_thresholds : int
            Number of thresholds, evenly distributed in `[0, 1]`, to compute PR metrics for.
            Should be `>= 2` and a plain integer rather than a tensor holding an integer.
            Values greater than 127 are clipped to 127 and a warning is logged.
            The thresholds for computing the pr curves are calculated in the following way:
            `width = 1.0 / (num_thresholds - 1),
            thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]`.
        weights : MXNet `NDArray` or `numpy.ndarray`.
            Optional float32 tensor. Individual counts are multiplied by this value.
            This tensor must be either the same shape as or broadcastable to the `labels` tensor.

    Returns
    -------
        A `Summary` protobuf of the pr_curve.
    """
    # Creating the protobuf fails once num_thresholds exceeds 127 (probably a
    # protobuf bug), so larger requests are clamped to that limit.
    max_thresholds = 127
    if num_thresholds > max_thresholds:
        logging.warning('num_thresholds>127 would result in failure of creating pr_curve protobuf,'
                        ' clipping it at 127')
        num_thresholds = max_thresholds

    # Normalize every input to a numpy array before computing the curve.
    labels = _make_numpy_array(labels)
    predictions = _make_numpy_array(predictions)
    weights = None if weights is None else _make_numpy_array(weights)
    curve = _compute_curve(labels, predictions,
                           num_thresholds=num_thresholds, weights=weights)

    # Plugin metadata records the (possibly clipped) threshold count that was
    # used to produce the curve.
    serialized_plugin_data = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=[SummaryMetadata.PluginData(plugin_name='pr_curves',
                                                content=serialized_plugin_data)])

    # The curve is stored flattened, with its first two dimensions kept in the
    # tensor shape.
    num_rows, num_cols = curve.shape[0], curve.shape[1]
    shape_proto = TensorShapeProto(dim=[TensorShapeProto.Dim(size=num_rows),
                                        TensorShapeProto.Dim(size=num_cols)])
    flat_values = curve.reshape(-1).tolist()
    tensor = TensorProto(dtype='DT_FLOAT', float_val=flat_values, tensor_shape=shape_proto)

    summary_value = Summary.Value(tag=tag, metadata=metadata, tensor=tensor)
    return Summary(value=[summary_value])