in tensorflow_model_analysis/utils/model_util.py [0:0]
def _lookup_feature(key: str, transformed_features: Any, features: Any,
                    allow_missing: bool):
    """Looks up `key`, preferring transformed features over raw features.

    Args:
        key: Feature key to look up.
        transformed_features: Transformed features for one row (and model).
        features: Raw features for one row.
        allow_missing: If True, a key absent from both mappings is reported as
            found with a value of None.

    Returns:
        A (found, value) tuple. `found` is False only when the key is absent
        from both mappings and allow_missing is False.
    """
    if key in transformed_features:
        return True, transformed_features[key]
    if key in features:
        return True, features[key]
    if allow_missing:
        return True, None
    return False, None


def get_feature_values_for_model_spec_field(
    model_specs: List[config_pb2.ModelSpec],
    field: str,
    multi_output_field: Optional[str],
    batched_extracts: types.Extracts,
    allow_missing: bool = False) -> Optional[Any]:
    """Gets feature values associated with given model spec fields from extracts.

    For every row, each model spec's key is looked up first in the row's
    transformed features and then in the row's raw features.

    Args:
        model_specs: List of model specs from EvalConfig.
        field: Name of field used to determine the feature(s) to extract. This
            should be an attribute on the ModelSpec such as "label_key",
            "example_weight_key", or "prediction_key".
        multi_output_field: Optional name of field used to store multi-output
            versions of the features. This should be an attribute on the
            ModelSpec such as "label_keys", "example_weight_keys", or
            "prediction_keys". This field is only used if a value at field is
            not found.
        batched_extracts: Extracts containing batched features keyed by
            tfma.FEATURES_KEY and optionally tfma.TRANSFORMED_FEATURES_KEY.
            If neither of those is present and non-empty, the batch size is
            taken from the `num_rows` of the value at
            tfma.ARROW_RECORD_BATCH_KEY. When more than one model spec is
            given and a row's transformed features contain an entry for the
            spec's name, that entry is used as the model's transformed
            features; otherwise the row's transformed features are used as-is.
        allow_missing: True if the feature may be missing (in which case None
            will be used as the value).

    Returns:
        A list with one entry per row. Each entry is the feature value stored
        at the given key, or a dict keyed by output name if the multi-output
        field was used. If multiple models are used, the value(s) are stored in
        a dict keyed by model name. Rows for which nothing was found are None.
        If no row produced any value (e.g. nothing was found and allow_missing
        is False, or the batch is empty), None is returned instead of a list.
    """
    # Resolve which batched feature sources are available once, up front;
    # the extracts are not modified while iterating.
    features_batch = None
    if (constants.FEATURES_KEY in batched_extracts and
            batched_extracts[constants.FEATURES_KEY]):
        features_batch = batched_extracts[constants.FEATURES_KEY]
    transformed_batch = None
    if (constants.TRANSFORMED_FEATURES_KEY in batched_extracts and
            batched_extracts[constants.TRANSFORMED_FEATURES_KEY]):
        transformed_batch = batched_extracts[constants.TRANSFORMED_FEATURES_KEY]

    if features_batch is not None:
        batch_size = len(features_batch)
    elif transformed_batch is not None:
        batch_size = len(transformed_batch)
    else:
        batch_size = batched_extracts[constants.ARROW_RECORD_BATCH_KEY].num_rows

    multi_model = len(model_specs) > 1
    batched_values = []
    all_none = True
    for i in range(batch_size):
        # A falsy raw-features row is treated as "no features", matching the
        # handling of transformed features below (previously a None row
        # raised TypeError on the `in` lookup).
        features = (features_batch[i] or {}) if features_batch is not None else {}
        row_transformed = (
            transformed_batch[i] if transformed_batch is not None else None)

        values = {}
        for spec in model_specs:
            # Get transformed features (if any) for this model. With several
            # models, transformed features may be nested under the model name.
            transformed_features = row_transformed
            if (multi_model and transformed_features and
                    spec.name in transformed_features):
                transformed_features = transformed_features[spec.name]
            transformed_features = transformed_features or {}

            key = getattr(spec, field, None)
            if key:
                found, value = _lookup_feature(
                    key, transformed_features, features, allow_missing)
                if found:
                    values[spec.name] = value
                continue

            output_keys = (getattr(spec, multi_output_field, None)
                           if multi_output_field else None)
            if output_keys:
                output_values = {}
                for output_name, output_key in output_keys.items():
                    found, value = _lookup_feature(
                        output_key, transformed_features, features,
                        allow_missing)
                    if found:
                        output_values[output_name] = value
                if output_values:
                    values[spec.name] = output_values
            elif allow_missing:
                values[spec.name] = None

        if values:
            all_none = False
            # If only one model, the output is stored without using a dict.
            if not multi_model:
                values = next(iter(values.values()))
        else:
            values = None
        batched_values.append(values)
    return None if all_none else batched_values