def filter_metrics()

in model_card_toolkit/utils/tfx_util.py [0:0]


def filter_metrics(
    eval_result: tfma.EvalResult,
    metrics_include: Optional[List[str]] = None,
    metrics_exclude: Optional[List[str]] = None) -> tfma.EvalResult:
  """Filters metrics in a TFMA EvalResult.

  Exactly one of `metrics_include` and `metrics_exclude` must be given as a
  non-empty list. The input `eval_result` is not modified: new nested metric
  dicts are built, and the returned EvalResult reuses the input's plots,
  attributions, config, data location, file format and model location.

  Args:
    eval_result: The TFMA EvalResult object.
    metrics_include: The names of metrics to keep in the EvalResult. Mutually
      exclusive with metrics_exclude.
    metrics_exclude: The names of metrics to discard in the EvalResult. Mutually
      exclusive with metrics_include.

  Returns:
    A new EvalResult whose slicing_metrics contain only the wanted metrics.
    Every slice from the input is kept, possibly with an empty metrics dict.
    Output names and sub keys that end up with no metrics are dropped.

  Raises:
    ValueError: unless exactly one of metrics_include and metrics_exclude is
      provided. An empty list counts as not provided.
  """
  if metrics_include and not metrics_exclude:
    listed_names = frozenset(metrics_include)
    keep_listed = True
  elif metrics_exclude and not metrics_include:
    listed_names = frozenset(metrics_exclude)
    keep_listed = False
  else:
    raise ValueError('filter_metrics() requires exactly one of metrics_include '
                     'and metrics_exclude.')

  filtered_slicing_metrics = []
  for slc, mtrc in eval_result.slicing_metrics:
    # mtrc is nested as {output_name: {subkey: {metric_name: value}}}; this is
    # the shape the indexing below relies on.
    filtered_mtrc = {}
    for output_name, metrics_by_subkey in mtrc.items():
      for subkey, metrics_by_name in metrics_by_subkey.items():
        # A metric is kept when its "listed" status matches the mode:
        # in include mode keep listed names, in exclude mode keep unlisted
        # ones.
        kept = {
            mtrc_name: value
            for mtrc_name, value in metrics_by_name.items()
            if (mtrc_name in listed_names) == keep_listed
        }
        # Only create output/subkey entries that actually hold metrics.
        if kept:
          filtered_mtrc.setdefault(output_name, {})[subkey] = kept
    filtered_slicing_metrics.append(
        tfma.view.SlicedMetrics(slice=slc, metrics=filtered_mtrc))

  return tfma.EvalResult(
      slicing_metrics=filtered_slicing_metrics,
      plots=eval_result.plots,
      attributions=eval_result.attributions,
      config=eval_result.config,
      data_location=eval_result.data_location,
      file_format=eval_result.file_format,
      model_location=eval_result.model_location)