in tensorflow_model_analysis/metrics/metric_specs.py [0:0]
def _process_tfma_metrics_specs(
    tfma_metrics_specs: List[config_pb2.MetricsSpec],
    per_tfma_spec_metric_instances: List[List[metric_types.Metric]],
    eval_config: config_pb2.EvalConfig,
    schema: Optional[schema_pb2.Schema]) -> metric_types.MetricComputations:
  """Processes list of TFMA MetricsSpecs to create computations.

  Args:
    tfma_metrics_specs: MetricsSpecs containing TFMA metrics.
    per_tfma_spec_metric_instances: Metric instances parallel to
      `tfma_metrics_specs` — one inner list per spec, aligned element-wise
      with that spec's `metrics` field.
    eval_config: Eval config passed through to each metric's `computations`.
    schema: Optional schema passed through to each metric's `computations`.

  Returns:
    The combined list of metric computations for all metrics in all specs.
  """
  #
  # Computations are per metric, so group the specs by the metric configs
  # they contain.
  #
  # Dict[bytes,List[config_pb2.MetricsSpec]] (hash(MetricConfig)->[MetricsSpec])
  tfma_specs_by_metric_config = {}
  # Dict[bytes,metric_types.Metric] (hash(MetricConfig)->Metric). Keeps the
  # first metric instance seen for each distinct config so identical configs
  # in different specs share one instance.
  hashed_metrics = {}
  for spec, metric_instances in zip(tfma_metrics_specs,
                                    per_tfma_spec_metric_instances):
    for metric_config, metric in zip(spec.metrics, metric_instances):
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process (proto serialization is not canonical across
      # processes).
      config_hash = metric_config.SerializeToString()
      if config_hash not in tfma_specs_by_metric_config:
        hashed_metrics[config_hash] = metric
        tfma_specs_by_metric_config[config_hash] = []
      tfma_specs_by_metric_config[config_hash].append(spec)
  #
  # Create computations for each metric.
  #
  result = []
  for config_hash, specs in tfma_specs_by_metric_config.items():
    metric = hashed_metrics[config_hash]
    for spec in specs:
      sub_keys_by_aggregation_type = _create_sub_keys(spec)
      # Keep track of sub-keys that can be shared between macro averaging and
      # binarization. For example, if macro averaging is being performed over
      # 10 classes and 5 of the classes are also being binarized, then those 5
      # classes can be re-used by the macro averaging calculation. The
      # remaining 5 classes need to be added as private metrics since those
      # classes were not requested but are still needed for the macro
      # averaging calculation.
      shared_sub_keys = set(sub_keys_by_aggregation_type.get(None, ()))
      for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items():
        class_weights = _class_weights(spec) if aggregation_type else None
        is_macro = bool(
            aggregation_type and (aggregation_type.macro_average or
                                  aggregation_type.weighted_macro_average))
        if is_macro:
          # Only the per-class sub-keys not already covered by the shared
          # (binarized) sub-keys need new computations.
          updated_sub_keys = [
              key for sub_key in sub_keys
              for key in _macro_average_sub_keys(sub_key, class_weights)
              if key not in shared_sub_keys
          ]
          if not updated_sub_keys:
            # Every per-class metric needed is already shared; nothing to add.
            continue
          # The macro average itself is derived from the per-class metrics,
          # so the underlying per-class computations are created without
          # aggregation, and as private metrics (they were not explicitly
          # requested).
          aggregation_type = None
          class_weights = None
          sub_keys = updated_sub_keys
          instance = _private_tfma_metric(metric)
        else:
          instance = metric
        for example_weighted in _example_weight_options(eval_config, spec):
          result.extend(
              instance.computations(
                  eval_config=eval_config,
                  schema=schema,
                  model_names=list(spec.model_names) or [''],
                  output_names=list(spec.output_names) or [''],
                  sub_keys=sub_keys,
                  aggregation_type=aggregation_type,
                  # Normalize an empty mapping to None for downstream checks.
                  class_weights=class_weights if class_weights else None,
                  example_weighted=example_weighted,
                  query_key=spec.query_key))
  return result