in model_card_toolkit/utils/tfx_util.py [0:0]
def _get_tfx_pipeline_types(store: mlmd.MetadataStore) -> PipelineTypes:
  """Retrieves the registered types in the given `store`.

  Args:
    store: A ml-metadata MetadataStore to retrieve the ArtifactTypes and
      ExecutionTypes from.

  Returns:
    An instance of PipelineTypes containing store pipeline types.

  Raises:
    ValueError: If the `store` does not have MCT related types and is not
      considered a valid TFX store.
  """
  artifact_types = {atype.name: atype for atype in store.get_artifact_types()}
  _raise_if_types_missing(
      'ArtifactTypes', artifact_types,
      {_TFX_DATASET_TYPE, _TFX_STATS_TYPE, _TFX_MODEL_TYPE, _TFX_METRICS_TYPE})
  # Artifact types are validated first, so a store missing both kinds reports
  # only the missing ArtifactTypes.
  execution_types = {etype.name: etype for etype in store.get_execution_types()}
  _raise_if_types_missing('ExecutionTypes', execution_types,
                          {_TFX_TRAINER_TYPE})
  return PipelineTypes(
      dataset_type=artifact_types[_TFX_DATASET_TYPE],
      stats_type=artifact_types[_TFX_STATS_TYPE],
      model_type=artifact_types[_TFX_MODEL_TYPE],
      metrics_type=artifact_types[_TFX_METRICS_TYPE],
      trainer_type=execution_types[_TFX_TRAINER_TYPE])


def _raise_if_types_missing(kind: str, registered: dict, expected: set) -> None:
  """Raises ValueError if any name in `expected` is not a key of `registered`.

  Args:
    kind: Human-readable label for the type family, used in the message.
    registered: Mapping of type name to type, as read from the store.
    expected: Type names that must be present in `registered`.

  Raises:
    ValueError: If at least one expected type name is missing.
  """
  missing_types = expected - registered.keys()
  if missing_types:
    # Sort so the message is stable across runs; set iteration order of str
    # values depends on per-process hash randomization.
    raise ValueError(
        f'Given `store` is invalid: missing {kind}: {sorted(missing_types)}.')