in model_card_toolkit/utils/tfx_util.py [0:0]
def read_stats_proto(
    stats_artifact_uri: str,
    split: str) -> Optional[statistics_pb2.DatasetFeatureStatisticsList]:
  """Reads a DatasetFeatureStatisticsList proto from a stats artifact uri.

  Two files are looked up under `<stats_artifact_uri>/<split>/`, in this
  order:
    1. `FeatureStats.pb`: read in binary mode and parsed as the proto.
    2. `stats_tfrecord`: a TFRecord file whose first record is parsed as the
       proto.

  Args:
    stats_artifact_uri: the output artifact path of a StatsGen component.
    split: the data split to fetch stats from.

  Returns:
    The stats of the requested split as a DatasetFeatureStatisticsList.
    Returns None (after logging a warning) if neither file exists, or if the
    `stats_tfrecord` file exists but contains no records.
  """
  feature_stats_path = os.path.join(stats_artifact_uri, split,
                                    'FeatureStats.pb')
  stats_tfrecord_path = os.path.join(stats_artifact_uri, split,
                                     'stats_tfrecord')

  # Preferred source: the binary proto file, when present.
  if tf.io.gfile.exists(feature_stats_path):
    stats = statistics_pb2.DatasetFeatureStatisticsList()
    with tf.io.gfile.GFile(feature_stats_path, mode='rb') as f:
      stats.ParseFromString(f.read())
    return stats

  # Fallback source: the TFRecord file; only its first record is used.
  if tf.io.gfile.exists(stats_tfrecord_path):
    # Passing a default to next() avoids leaking a bare StopIteration to the
    # caller when the file holds no records.
    serialized_stats = next(
        tf.compat.v1.io.tf_record_iterator(stats_tfrecord_path), None)
    if serialized_stats is None:
      logging.warning('No records found in %s', stats_tfrecord_path)
      return None
    stats = statistics_pb2.DatasetFeatureStatisticsList()
    stats.ParseFromString(serialized_stats)
    return stats

  logging.warning('No artifact found at %s or %s', stats_tfrecord_path,
                  feature_stats_path)
  return None