in model_card_toolkit/utils/tfx_util.py [0:0]
def filter_metrics(
    eval_result: tfma.EvalResult,
    metrics_include: Optional[List[str]] = None,
    metrics_exclude: Optional[List[str]] = None) -> tfma.EvalResult:
  """Filters metrics in a TFMA EvalResult.

  Exactly one of `metrics_include` and `metrics_exclude` must be a non-empty
  list. Only `slicing_metrics` is filtered; plots, attributions, config,
  data_location, file_format and model_location are passed through unchanged
  to the returned object.

  Args:
    eval_result: The TFMA EvalResult object. It is not mutated; a new
      EvalResult is built instead.
    metrics_include: The names of metrics to keep in the EvalResult. Mutually
      exclusive with metrics_exclude.
    metrics_exclude: The names of metrics to discard in the EvalResult.
      Mutually exclusive with metrics_include.

  Returns:
    A new EvalResult with unwanted metrics filtered. Every slice of the input
    is kept, even when none of its metrics survive (its metrics dict is then
    empty). Output names and sub keys left with no surviving metrics are
    omitted.

  Raises:
    ValueError: if both or neither of metrics_include and metrics_exclude are
      provided. An empty list counts as not provided.
  """
  if metrics_include and not metrics_exclude:
    # Build the lookup set once so each per-metric check is O(1) instead of
    # scanning the list for every metric of every slice.
    wanted = frozenset(metrics_include)
    include = lambda metric_name: metric_name in wanted
  elif metrics_exclude and not metrics_include:
    unwanted = frozenset(metrics_exclude)
    include = lambda metric_name: metric_name not in unwanted
  else:
    raise ValueError('filter_metrics() requires exactly one of metrics_include '
                     'and metrics_exclude.')

  filtered_slicing_metrics = []
  # Each slicing_metrics entry is unpacked as (slice, metrics), where metrics
  # is nested as {output_name: {subkey: {metric_name: value}}}.
  for slc, mtrc in eval_result.slicing_metrics:
    filtered_mtrc = {}
    for output_name, metrics_by_subkey in mtrc.items():
      for subkey, metrics_by_name in metrics_by_subkey.items():
        kept = {
            name: value
            for name, value in metrics_by_name.items()
            if include(name)
        }
        # Only create the output/subkey containers when something survives,
        # so empty groups never appear in the result.
        if kept:
          filtered_mtrc.setdefault(output_name, {})[subkey] = kept
    filtered_slicing_metrics.append(
        tfma.view.SlicedMetrics(slice=slc, metrics=filtered_mtrc))
  return tfma.EvalResult(
      slicing_metrics=filtered_slicing_metrics,
      plots=eval_result.plots,
      attributions=eval_result.attributions,
      config=eval_result.config,
      data_location=eval_result.data_location,
      file_format=eval_result.file_format,
      model_location=eval_result.model_location)