in tensorflow_model_analysis/metrics/tf_metric_wrapper.py
def _wrap_confusion_matrix_metric(
    metric: tf.keras.metrics.Metric, eval_config: config_pb2.EvalConfig,
    model_name: str, output_name: str, sub_key: Optional[metric_types.SubKey],
    aggregation_type: Optional[metric_types.AggregationType],
    class_weights: Optional[Dict[int, float]],
    example_weighted: bool) -> metric_types.MetricComputations:
  """Returns confusion matrix metric wrapped in a more efficient computation."""
  # Special handling for the AUC metric, which supports aggregation inherently
  # via its multi_label flag.
  if (isinstance(metric, tf.keras.metrics.AUC) and
      hasattr(metric, 'label_weights')):
    if metric.label_weights:
      if class_weights:
        raise ValueError(
            'class weights are configured in two different places: (1) via the '
            'tf.keras.metrics.AUC class (using "label_weights") and (2) via '
            'the MetricsSpecs (using "aggregate.class_weights"). Either remove '
            'the label_weights setting in the AUC class or remove the '
            'class_weights from the AggregationOptions: metric={}, '
            'class_weights={}'.format(metric, class_weights))
      class_weights = {i: v for i, v in enumerate(metric.label_weights)}
    if metric.multi_label:
      raise NotImplementedError('AUC.multi_label=True is not implemented yet.')
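
  # _verify_and_update_sub_key checks that any binarization sub key passed via
  # the config is consistent with the metric's own constructor settings (and
  # may derive one from them).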
  sub_key = _verify_and_update_sub_key(model_name, output_name, sub_key,
                                       metric)
  key = metric_types.MetricKey(
      name=metric.name,
      model_name=model_name,
      output_name=output_name,
      aggregation_type=aggregation_type,
      sub_key=sub_key,
      example_weighted=example_weighted)
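
  # Serialize the metric config so that result() below can deserialize a
  # fresh, zero-state copy on each call instead of mutating this instance.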
  metric_config = tf.keras.metrics.serialize(metric)

  thresholds = None
  num_thresholds = None
  # The top_k metrics have special settings. If the top_k value is set outside
  # of keras (i.e. using BinarizeOptions), then we need to set the special
  # threshold ourselves; otherwise the default threshold of 0.5 is used.
  if (sub_key and sub_key.top_k is not None and
      _get_config_value(_TOP_K_KEY, metric_config) is None and
      _get_config_value(_THRESHOLDS_KEY, metric_config) is None and
      _get_config_value(_NUM_THRESHOLDS_KEY, metric_config) is None):
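    # With top_k binarization, non-top-k predictions are typically pushed to
    # -inf upstream, so a -inf threshold counts exactly the top-k entries as
    # positive predictions.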
    thresholds = [float('-inf')]
  elif hasattr(metric, _THRESHOLDS_KEY):
    thresholds = metric.thresholds
  # Only one of thresholds or num_thresholds should be used. Keras AUC allows
  # both, but thresholds takes precedence.
  if thresholds is None and hasattr(metric, _NUM_THRESHOLDS_KEY):
    num_thresholds = metric.num_thresholds

  # Make sure matrices are calculated.
  computations = binary_confusion_matrices.binary_confusion_matrices(
      num_thresholds=num_thresholds,
      thresholds=thresholds,
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      aggregation_type=aggregation_type,
      class_weights=class_weights,
      example_weighted=example_weighted)
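  # The last key of the last computation is the key under which the
  # accumulated matrices are emitted; result() looks them up by this key.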
  matrices_key = computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey, Any]:
    """Returns result derived from binary confusion matrices."""
    matrices = metrics[matrices_key]
    metric = tf.keras.metrics.deserialize(metric_config)
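    # Inject the precomputed confusion counts directly into the Keras metric's
    # internal accumulator variables; metric.result() then computes the final
    # value without ever seeing individual examples.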
    if isinstance(metric, (tf.keras.metrics.AUC,
                           tf.keras.metrics.SpecificityAtSensitivity,
                           tf.keras.metrics.SensitivityAtSpecificity)):
      metric.true_positives.assign(np.array(matrices.tp))
      metric.true_negatives.assign(np.array(matrices.tn))
      metric.false_positives.assign(np.array(matrices.fp))
      metric.false_negatives.assign(np.array(matrices.fn))
    elif isinstance(metric, tf.keras.metrics.Precision):
      metric.true_positives.assign(np.array(matrices.tp))
      metric.false_positives.assign(np.array(matrices.fp))
    elif isinstance(metric, tf.keras.metrics.Recall):
      metric.true_positives.assign(np.array(matrices.tp))
      metric.false_negatives.assign(np.array(matrices.fn))
    elif isinstance(metric, tf.keras.metrics.TruePositives):
      metric.accumulator.assign(np.array(matrices.tp))
    elif isinstance(metric, tf.keras.metrics.FalsePositives):
      metric.accumulator.assign(np.array(matrices.fp))
    elif isinstance(metric, tf.keras.metrics.TrueNegatives):
      metric.accumulator.assign(np.array(matrices.tn))
    elif isinstance(metric, tf.keras.metrics.FalseNegatives):
      metric.accumulator.assign(np.array(matrices.fn))
    return {key: metric.result().numpy()}

  derived_computation = metric_types.DerivedMetricComputation(
      keys=[key], result=result)
  computations.append(derived_computation)
  return computations
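
# Illustrative call (a sketch, not from the original file; argument values are
# assumptions chosen to satisfy the signature above):
#   computations = _wrap_confusion_matrix_metric(
#       metric=tf.keras.metrics.Precision(name='precision'),
#       eval_config=config_pb2.EvalConfig(), model_name='', output_name='',
#       sub_key=None, aggregation_type=None, class_weights=None,
#       example_weighted=False)
#   # -> binary_confusion_matrices computation(s) plus a derived computation
#   #    that reports the 'precision' value.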