# tensorflow_model_analysis/metrics/binary_confusion_matrices.py
def binary_confusion_matrices(
num_thresholds: Optional[int] = None,
thresholds: Optional[List[float]] = None,
name: Optional[str] = None,
eval_config: Optional[config_pb2.EvalConfig] = None,
model_name: str = '',
output_name: str = '',
sub_key: Optional[metric_types.SubKey] = None,
aggregation_type: Optional[metric_types.AggregationType] = None,
class_weights: Optional[Dict[int, float]] = None,
example_weighted: bool = False,
use_histogram: Optional[bool] = None,
extract_label_prediction_and_weight: Optional[Callable[
..., Any]] = metric_util.to_label_prediction_example_weight,
preprocessor: Optional[Callable[..., Any]] = None,
examples_name: Optional[str] = None,
example_id_key: Optional[str] = None,
example_ids_count: Optional[int] = None,
    fractional_labels: bool = True) -> metric_types.MetricComputations:
"""Returns metric computations for computing binary confusion matrices.
Args:
    num_thresholds: Number of thresholds to use. Thresholds will be calculated
      using linear interpolation between 0.0 and 1.0 with equidistant values
      and boundaries at -epsilon and 1.0+epsilon. Only one of num_thresholds
      or thresholds should be used. If used, num_thresholds must be > 1.
    thresholds: A specific set of thresholds to use. The caller is responsible
      for marking the boundaries with +/-epsilon if desired. Only one of
      num_thresholds or thresholds should be used. For metrics computed at
      top k, this may be a single negative threshold value (i.e. -inf).
name: Metric name containing binary_confusion_matrices.Matrices.
eval_config: Eval config.
model_name: Optional model name (if multi-model evaluation).
output_name: Optional output name (if multi-output model type).
sub_key: Optional sub key.
aggregation_type: Optional aggregation type.
class_weights: Optional class weights to apply to multi-class / multi-label
labels and predictions prior to flattening (when micro averaging is used).
example_weighted: True if example weights should be applied.
use_histogram: If true, matrices will be derived from calibration
histograms.
extract_label_prediction_and_weight: User-provided function argument that
yields label, prediction, and example weights for use in calculations
(relevant only when use_histogram flag is not true).
preprocessor: User-provided preprocessor for including additional extracts
in StandardMetricInputs (relevant only when use_histogram flag is not
true).
    examples_name: Metric name containing binary_confusion_matrices.Examples
      (relevant only when use_histogram flag is not true and example_id_key is
      set).
example_id_key: Feature key containing example id (relevant only when
use_histogram flag is not true).
example_ids_count: Max number of example ids to be extracted for false
positives and false negatives (relevant only when use_histogram flag is
not true).
    fractional_labels: If true, each incoming tuple of (label, prediction,
      example weight) will be split into two tuples as follows (where l, p,
      and w represent the resulting label, prediction, and example weight
      values):
        (1) l = 0.0, p = prediction, w = example_weight * (1.0 - label)
        (2) l = 1.0, p = prediction, w = example_weight * label
      If enabled, an exception will be raised if labels are not within
      [0, 1]. The implementation is such that tuples associated with a weight
      of zero are not yielded. This means it is safe to enable
      fractional_labels even when the labels only take on the values of 0.0
      or 1.0.
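      For example, a tuple with label=0.3, prediction=0.8, and
      example_weight=1.0 is split into (l=0.0, p=0.8, w=0.7) and
      (l=1.0, p=0.8, w=0.3).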

  Raises:
ValueError: If both num_thresholds and thresholds are set at the same time.
"""
# TF v1 Keras AUC turns num_thresholds parameters into thresholds which
# circumvents sharing of settings. If the thresholds match the interpolated
# version of the thresholds then reset back to num_thresholds.
if thresholds:
if (not num_thresholds and
thresholds == _interpolated_thresholds(len(thresholds))):
num_thresholds = len(thresholds)
thresholds = None
elif (num_thresholds
in (DEFAULT_NUM_THRESHOLDS, _KERAS_DEFAULT_NUM_THRESHOLDS) and
len(thresholds) == num_thresholds - 2):
thresholds = None
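  # E.g., an explicit thresholds list equal to _interpolated_thresholds(N) is
  # normalized back to num_thresholds=N, and a Keras-style interior list
  # (len(thresholds) == num_thresholds - 2, with the epsilon boundaries
  # implicit) defers to the matching num_thresholds.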
  if num_thresholds is not None and thresholds is not None:
    raise ValueError(
        'only one of thresholds or num_thresholds can be set at a time: '
        f'num_thresholds={num_thresholds}, thresholds={thresholds}, '
        f'len(thresholds)={len(thresholds)}')
if num_thresholds is None and thresholds is None:
num_thresholds = DEFAULT_NUM_THRESHOLDS
if num_thresholds is not None:
if num_thresholds <= 1:
raise ValueError('num_thresholds must be > 1')
# The interpolation strategy used here matches that used by keras for AUC.
thresholds = _interpolated_thresholds(num_thresholds)
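    # E.g., num_thresholds=5 yields approximately
    # [-epsilon, 0.25, 0.5, 0.75, 1.0 + epsilon].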
thresholds_name_part = str(num_thresholds)
else:
thresholds_name_part = str(list(thresholds))
if use_histogram is None:
use_histogram = (
num_thresholds is not None or
(len(thresholds) == 1 and thresholds[0] < 0))
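  # I.e., the histogram path is taken for interpolated threshold grids and for
  # the single large negative threshold used by top-k precision/recall;
  # explicit threshold lists otherwise use the direct matrix computation.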
if use_histogram and (examples_name or example_id_key or example_ids_count):
raise ValueError('Example sampling is only performed when not using the '
'histogram computation. However, use_histogram is true '
f'and one of examples_name ("{examples_name}"), '
                     f'example_id_key ("{example_id_key}"), '
f'or example_ids_count ({example_ids_count}) was '
'provided, which will have no effect.')
if examples_name and not (example_id_key and example_ids_count):
raise ValueError('examples_name provided but either example_id_key or '
'example_ids_count was not. Examples will only be '
'returned when both example_id_key and '
'example_ids_count are provided, and when the '
'non-histogram computation is used. '
f'example_id_key: "{example_id_key}" '
f'example_ids_count: {example_ids_count}')
if name is None:
name = f'{BINARY_CONFUSION_MATRICES_NAME}_{thresholds_name_part}'
if examples_name is None:
examples_name = f'{BINARY_CONFUSION_EXAMPLES_NAME}_{thresholds_name_part}'
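  # E.g., assuming BINARY_CONFUSION_MATRICES_NAME is
  # 'binary_confusion_matrices', num_thresholds=1000 yields the default name
  # 'binary_confusion_matrices_1000' (and similarly for examples_name).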
matrices_key = metric_types.MetricKey(
name=name,
model_name=model_name,
output_name=output_name,
sub_key=sub_key,
example_weighted=example_weighted)
examples_key = metric_types.MetricKey(
name=examples_name,
model_name=model_name,
output_name=output_name,
sub_key=sub_key,
example_weighted=example_weighted)
computations = []
if use_histogram:
# Use calibration histogram to calculate matrices. For efficiency (unless
# all predictions are matched - i.e. thresholds <= 0) we will assume that
# other metrics will make use of the calibration histogram and re-use the
# default histogram for the given model_name/output_name/sub_key. This is
# also required to get accurate counts at the threshold boundaries. If this
# becomes an issue, then calibration histogram can be updated to support
# non-linear boundaries.
computations = calibration_histogram.calibration_histogram(
eval_config=eval_config,
num_buckets=(
            # For precision/recall_at_k where a single large negative threshold
# is used, we only need one bucket. Note that the histogram will
# actually have 2 buckets: one that we set (which handles
# predictions > -1.0) and a default catch-all bucket (i.e. bucket 0)
# that the histogram creates for large negative predictions (i.e.
# predictions <= -1.0).
1 if len(thresholds) == 1 and thresholds[0] <= 0 else None),
model_name=model_name,
output_name=output_name,
sub_key=sub_key,
aggregation_type=aggregation_type,
class_weights=class_weights,
example_weighted=example_weighted)
input_metric_key = computations[-1].keys[-1]
output_metric_keys = [matrices_key]
else:
if bool(example_ids_count) != bool(example_id_key):
raise ValueError('Both of example_ids_count and example_id_key must be '
f'set, but got example_id_key: "{example_id_key}" and '
f'example_ids_count: {example_ids_count}.')
computations = _binary_confusion_matrix_computation(
eval_config=eval_config,
thresholds=thresholds,
model_name=model_name,
output_name=output_name,
sub_key=sub_key,
extract_label_prediction_and_weight=extract_label_prediction_and_weight,
preprocessor=preprocessor,
example_id_key=example_id_key,
example_ids_count=example_ids_count,
aggregation_type=aggregation_type,
class_weights=class_weights,
example_weighted=example_weighted,
fractional_labels=fractional_labels)
input_metric_key = computations[-1].keys[-1]
# matrices_key is last for backwards compatibility with code that:
# 1) used this computation as an input for a derived computation
# 2) only accessed the matrix counts
# 3) used computations[-1].keys[-1] to access the input key
output_metric_keys = [examples_key, matrices_key]

  def result(
metrics: Dict[metric_types.MetricKey, Any]
) -> Dict[metric_types.MetricKey, Union[Matrices, Examples]]:
"""Returns binary confusion matrices."""
matrices = None
if use_histogram:
if len(thresholds) == 1 and thresholds[0] < 0:
# This case is used when all positive prediction values are relevant
# matches (e.g. when calculating top_k for precision/recall where the
# non-top_k values are expected to have been set to float('-inf')).
histogram = metrics[input_metric_key]
else:
# Calibration histogram uses intervals of the form [start, end) where
# the prediction >= start. The confusion matrices want intervals of the
# form (start, end] where the prediction > start. Add a small epsilon so
# that >= checks don't match. This correction shouldn't be needed in
# practice but allows for correctness in small tests.
rebin_thresholds = [t + _EPSILON if t != 0 else t for t in thresholds]
if thresholds[0] >= 0:
# Add -epsilon bucket to account for differences in histogram vs
# confusion matrix intervals mentioned above. If the epsilon bucket is
# missing the false negatives and false positives will be 0 for the
# first threshold.
rebin_thresholds = [-_EPSILON] + rebin_thresholds
if thresholds[-1] < 1.0:
# If the last threshold < 1.0, then add a fence post at 1.0 + epsilon
          # otherwise true negatives and true positives will be overcounted.
rebin_thresholds = rebin_thresholds + [1.0 + _EPSILON]
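        # E.g., thresholds=[0.25, 0.75] is rebinned using
        # [-epsilon, 0.25 + epsilon, 0.75 + epsilon, 1.0 + epsilon].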
histogram = calibration_histogram.rebin(rebin_thresholds,
metrics[input_metric_key])
matrices = _histogram_to_binary_confusion_matrices(thresholds, histogram)
return {matrices_key: matrices}
else:
matrices, examples = _accumulator_to_matrices_and_examples(
thresholds, metrics[input_metric_key])
return {matrices_key: matrices, examples_key: examples}
derived_computation = metric_types.DerivedMetricComputation(
keys=output_metric_keys, result=result)
computations.append(derived_computation)
return computations
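

# A minimal usage sketch (hypothetical call; the values shown are assumptions,
# not defaults):
#
#   computations = binary_confusion_matrices(thresholds=[0.25, 0.5, 0.75])
#   # The matrices key is last for backwards compatibility (see above).
#   matrices_key = computations[-1].keys[-1]
#
# The final entry is a DerivedMetricComputation whose result maps matrices_key
# (and, in the non-histogram case, examples_key) to the computed values.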