def get_feature_values_for_model_spec_field()

in tensorflow_model_analysis/utils/model_util.py [0:0]


def get_feature_values_for_model_spec_field(
    model_specs: List[config_pb2.ModelSpec],
    field: str,
    multi_output_field: Optional[str],
    batched_extracts: types.Extracts,
    allow_missing: bool = False) -> Optional[Any]:
  """Gets feature values associated with given model spec fields from extracts.

  For every row in the batch and every model spec, the feature key(s) named by
  the spec are looked up first in the transformed features and then in the raw
  features.

  Args:
    model_specs: List of model specs from EvalConfig.
    field: Name of field used to determine the feature(s) to extract. This
      should be an attribute on the ModelSpec such as "label_key",
      "example_weight_key", or "prediction_key".
    multi_output_field: Optional name of field used to store multi-output
      versions of the features. This should be an attribute on the ModelSpec
      such as "label_keys", "example_weight_keys", or "prediction_keys". This
      field is only consulted when the spec has no value set at field.
    batched_extracts: Extracts containing batched features keyed by
      tfma.FEATURES_KEY and optionally tfma.TRANSFORMED_FEATURES_KEY. When
      neither is present (or both are empty) the batch size is taken from the
      record batch stored under constants.ARROW_RECORD_BATCH_KEY.
    allow_missing: True if the feature may be missing (in which case None will
      be used as the value).

  Returns:
    A list with one entry per row in the batch. Each entry is the feature value
    stored at the given key (or the feature values stored at each output keyed
    by output name if the field containing a map of feature keys was used). If
    multiple models are used the value(s) are stored in a dict keyed by model
    name. A row for which nothing was found is represented by None, and if
    nothing was found for any row then None is returned instead of a list.
  """
  # The batched containers are the same for every row, so resolve them (and
  # their emptiness) once instead of once per row and per model spec.
  batched_features = batched_extracts.get(constants.FEATURES_KEY) or None
  batched_transformed = (
      batched_extracts.get(constants.TRANSFORMED_FEATURES_KEY) or None)

  if batched_features:
    batch_size = len(batched_features)
  elif batched_transformed:
    batch_size = len(batched_transformed)
  else:
    batch_size = batched_extracts[constants.ARROW_RECORD_BATCH_KEY].num_rows

  multi_model = len(model_specs) > 1

  def lookup(key, transformed_features, features):
    """Returns (found, value), preferring transformed over raw features."""
    if key in transformed_features:
      return True, transformed_features[key]
    if key in features:
      return True, features[key]
    # With allow_missing an absent feature is reported as an explicit None;
    # otherwise it is omitted from the output altogether.
    return allow_missing, None

  batched_values = []
  all_none = True
  for i in range(batch_size):
    # A row may be None; treat it like an empty row (this mirrors the handling
    # of transformed features below) rather than failing on the `in` checks.
    features = (batched_features[i] if batched_features else None) or {}
    transformed_row = batched_transformed[i] if batched_transformed else None

    values = {}
    for spec in model_specs:
      # Get transformed features (if any) for this model. With multiple models
      # the transformed features may be keyed by model name.
      transformed_features = transformed_row
      if (multi_model and transformed_features and
          spec.name in transformed_features):
        transformed_features = transformed_features[spec.name]
      transformed_features = transformed_features or {}

      key = getattr(spec, field, None)
      output_keys = (
          getattr(spec, multi_output_field, None)
          if multi_output_field else None)
      if key:
        found, value = lookup(key, transformed_features, features)
        if found:
          values[spec.name] = value
      elif output_keys:
        output_values = {}
        for output_name, output_key in output_keys.items():
          found, value = lookup(output_key, transformed_features, features)
          if found:
            output_values[output_name] = value
        if output_values:
          values[spec.name] = output_values
      elif allow_missing:
        values[spec.name] = None

    if values:
      all_none = False
      # If only one model, the output is stored without using a dict
      if not multi_model:
        values = next(iter(values.values()))
    else:
      values = None
    batched_values.append(values)
  return batched_values if not all_none else None