def read_stats_proto()

in model_card_toolkit/utils/tfx_util.py [0:0]


def read_stats_proto(
    stats_artifact_uri: str,
    split: str) -> Optional[statistics_pb2.DatasetFeatureStatisticsList]:
  """Reads DatasetFeatureStatisticsList proto from provided stats artifact uri.

  Two on-disk layouts are probed, in this order:
    1. `<stats_artifact_uri>/<split>/FeatureStats.pb`: a binary-serialized
       DatasetFeatureStatisticsList.
    2. `<stats_artifact_uri>/<split>/stats_tfrecord`: a TFRecord file whose
       first record is a serialized DatasetFeatureStatisticsList.

  Args:
    stats_artifact_uri: the output artifact path of a StatsGen component.
    split: the data split to fetch stats from.

  Returns:
    The stats for `split` as a DatasetFeatureStatisticsList. Returns None (and
    logs a warning) if neither file exists, or if the TFRecord file exists but
    contains no records.
  """
  split_dir = os.path.join(stats_artifact_uri, split)
  feature_stats_path = os.path.join(split_dir, 'FeatureStats.pb')
  stats_tfrecord_path = os.path.join(split_dir, 'stats_tfrecord')
  stats = statistics_pb2.DatasetFeatureStatisticsList()

  if tf.io.gfile.exists(feature_stats_path):
    with tf.io.gfile.GFile(feature_stats_path, mode='rb') as f:
      stats.ParseFromString(f.read())
    return stats

  if tf.io.gfile.exists(stats_tfrecord_path):
    # Only the first record is read. Defaulting to None keeps an empty file
    # from leaking a bare StopIteration to the caller (which would turn into a
    # RuntimeError if the caller is a generator).
    serialized_stats = next(
        tf.compat.v1.io.tf_record_iterator(stats_tfrecord_path), None)
    if serialized_stats is None:
      logging.warning('No records found in %s', stats_tfrecord_path)
      return None
    stats.ParseFromString(serialized_stats)
    return stats

  # Paths are listed in the same order they were checked above.
  logging.warning('No artifact found at %s or %s', feature_stats_path,
                  stats_tfrecord_path)
  return None