def _get_tfx_pipeline_types()

in model_card_toolkit/utils/tfx_util.py [0:0]


def _check_required_types(available: dict, expected: set, kind: str) -> None:
  """Raises a ValueError if any name in `expected` is absent from `available`.

  Args:
    available: Mapping from type name to the type object read from the store.
    expected: Names of the types that must be present in `available`.
    kind: Label used in the error message, e.g. 'ArtifactTypes'.

  Raises:
    ValueError: If at least one expected type name is missing.
  """
  missing = sorted(set(expected).difference(available))
  if missing:
    # Sort before rendering: the iteration order of a set of str depends on
    # hash randomization, so formatting the raw set would yield a different
    # message on each run. The braces + repr rendering keeps the familiar
    # set-literal look of the message.
    rendered = '{' + ', '.join(repr(name) for name in missing) + '}'
    raise ValueError(
        f'Given `store` is invalid: missing {kind}: {rendered}.')


def _get_tfx_pipeline_types(store: mlmd.MetadataStore) -> PipelineTypes:
  """Retrieves the registered types in the given `store`.

  Args:
    store: A ml-metadata MetadataStore to retrieve ArtifactTypes and
      ExecutionTypes from.

  Returns:
    An instance of PipelineTypes containing the store's pipeline types.

  Raises:
    ValueError: If the `store` is missing any of the ArtifactTypes or the
      trainer ExecutionType MCT relies on, i.e. it is not considered a valid
      TFX store.
  """
  artifact_types = {atype.name: atype for atype in store.get_artifact_types()}
  _check_required_types(
      artifact_types,
      {_TFX_DATASET_TYPE, _TFX_STATS_TYPE, _TFX_MODEL_TYPE, _TFX_METRICS_TYPE},
      'ArtifactTypes')
  # Execution types are only queried once the artifact types have validated,
  # so an invalid store is reported on the first missing group.
  execution_types = {etype.name: etype for etype in store.get_execution_types()}
  _check_required_types(execution_types, {_TFX_TRAINER_TYPE}, 'ExecutionTypes')
  return PipelineTypes(
      dataset_type=artifact_types[_TFX_DATASET_TYPE],
      stats_type=artifact_types[_TFX_STATS_TYPE],
      model_type=artifact_types[_TFX_MODEL_TYPE],
      metrics_type=artifact_types[_TFX_METRICS_TYPE],
      trainer_type=execution_types[_TFX_TRAINER_TYPE])