tfx/dsl/input_resolution/ops/graph_traversal_op.py (113 lines of code) (raw):

# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for GraphTraversal operator."""

from typing import Sequence

from absl import logging
from tfx import types
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.input_resolution.ops import ops_utils
from tfx.orchestration.portable.input_resolution.mlmd_resolver import metadata_resolver
from tfx.orchestration.portable.mlmd import event_lib
from tfx.orchestration.portable.mlmd import filter_query_builder as q
from tfx.types import artifact_utils

from ml_metadata.proto import metadata_store_pb2

# Valid artifact states for GraphTraversal.
_VALID_ARTIFACT_STATES = [metadata_store_pb2.Artifact.State.LIVE]


class GraphTraversal(
    resolver_op.ResolverOp,
    canonical_name='tfx.GraphTraversal',
    arg_data_types=(resolver_op.DataType.ARTIFACT_LIST,),
    return_data_type=resolver_op.DataType.ARTIFACT_MULTIMAP,
):
  """GraphTraversal operator.

  Starting from a single root artifact, looks up the artifacts that are
  upstream (or downstream) of it in MLMD and groups them by artifact type name.
  """

  # Whether to search for artifacts upstream or downstream. Required.
  traverse_upstream = resolver_op.Property(type=bool)

  # The artifact type names to search for, e.g. "ModelBlessing",
  # "TransformGraph". Should match Tflex standard artifact type names or user
  # defined custom artifact type names. Can not be empty.
  artifact_type_names = resolver_op.Property(type=Sequence[str])

  # The producer component node IDs to match by, e.g.
  # "example-gen.import-example". Optional.
  node_ids = resolver_op.Property(type=Sequence[str], default=[])

  # The Event output key(s) to match by. Optional.
  output_keys = resolver_op.Property(type=Sequence[str], default=[])

  def apply(self, input_list: Sequence[types.Artifact]):
    """Returns a dict with the upstream (or downstream) and root artifacts.

    Args:
      input_list: A list with exactly one Artifact to use as the root.

    Returns:
      A dictionary with the upstream (or downstream) artifacts, and the root
      artifact. An empty dictionary is returned if input_list is empty.

      For example, consider: Examples -> Model -> ModelBlessing.

      Calling GraphTraversal with [ModelBlessing], traverse_upstream=True, and
      artifact_type_names=["Examples"] will return:

      {
          "root_artifact": [ModelBlessing],
          "Examples": [Examples],
      }

      Every requested artifact type name is present as a key, mapped to an
      empty list when no matching artifact was found.

      Note the key "root_artifact" is set with the original artifact inside
      input_list. This makes input synchronization easier in an ASYNC pipeline.

    Raises:
      ValueError: If artifact_type_names is empty, if input_list holds more
        than one artifact, or if node_ids is set but the root artifact is not
        associated with any pipeline context.
    """
    if not input_list:
      return {}

    if not self.artifact_type_names:
      raise ValueError(
          'At least one artifact type name must be provided, but '
          'artifact_type_names was empty.'
      )

    # TODO(b/299985043): Support batch traversal.
    if len(input_list) != 1:
      raise ValueError(
          'GraphTraversal ResolverOp does not support batch traversal.'
      )

    root_artifact = input_list[0]

    # Query MLMD to get the upstream (or downstream) artifacts.
    artifact_states_filter_query = (
        ops_utils.get_valid_artifact_states_filter_query(_VALID_ARTIFACT_STATES)
    )
    filter_query = (
        f'type IN {q.to_sql_string(self.artifact_type_names)} AND '
        f'{artifact_states_filter_query}'
    )

    if self.node_ids:
      # The pipeline name is needed to build the node context names below; it
      # is taken from the first pipeline context attached to the root artifact.
      for context in self.context.store.get_contexts_by_artifact(
          root_artifact.id
      ):
        if context.type == constants.PIPELINE_CONTEXT_TYPE_NAME:
          pipeline_name = context.name
          break
      else:
        raise ValueError('No pipeline context was found.')

      # We match against the node Context's name, which has the format
      # <pipeline-name>.<node-id>
      node_context_names = [
          compiler_utils.node_context_name(pipeline_name, ni)
          for ni in self.node_ids
      ]
      query = (
          f'contexts_a.name IN {q.to_sql_string(node_context_names)} '
          'AND contexts_a.type = '
          f'{q.to_sql_string(constants.NODE_CONTEXT_TYPE_NAME)}'
      )
      filter_query += ' AND ' + query

    mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
    mlmd_resolver_fn = (
        mlmd_resolver.get_upstream_artifacts_by_artifact_ids
        if self.traverse_upstream
        else mlmd_resolver.get_downstream_artifacts_by_artifact_ids
    )
    related_artifact_and_type = mlmd_resolver_fn(
        [root_artifact.id],
        max_num_hops=ops_utils.GRAPH_TRAVERSAL_OP_MAX_NUM_HOPS,
        filter_query=filter_query,
    )

    # Split the (artifact, artifact_type) pairs into the artifacts per queried
    # artifact id and a lookup of ArtifactType by type id.
    artifact_type_by_id = {}
    related_artifacts_by_id = {}
    for artifact_id, artifacts_and_types in related_artifact_and_type.items():
      if not artifacts_and_types:
        # zip(*[]) yields nothing to unpack; an empty entry simply means there
        # are no related artifacts for this id.
        continue
      related_artifacts_by_id[artifact_id], artifact_types = zip(
          *artifacts_and_types
      )
      artifact_type_by_id.update({t.id: t for t in artifact_types})

    # Build the result dict to return. We include the root_artifact to help with
    # input synchronization in ASYNC mode. Note, Python dicts preserve key
    # insertion order, so when a user gets the unrolled dict values, they will
    # first get the root artifact, followed by ancestor/descendant artifacts in
    # the same order as self.artifact_type_names.
    result = {ops_utils.ROOT_ARTIFACT_KEY: [root_artifact]}
    for artifact_type in self.artifact_type_names:
      result[artifact_type] = []

    related_artifacts = related_artifacts_by_id.get(root_artifact.id)
    if not related_artifacts:
      logging.info(
          'No neighboring artifacts were found for root artifact %s and '
          'artifact_type_names %s node_ids %s output_keys %s.',
          root_artifact,
          self.artifact_type_names,
          self.node_ids,
          self.output_keys,
      )
      return result

    # Get the ArtifactType for the related artifacts.
    artifact_type_by_artifact_id = {
        artifact.id: artifact_type_by_id[artifact.type_id]
        for artifact in related_artifacts
    }

    # Events are only needed to filter by output key, so skip the MLMD lookup
    # entirely when no output keys were requested. If an artifact has several
    # valid output events, the last one returned by the store is kept.
    events_by_artifact_id = {}
    if self.output_keys:
      artifact_ids = {a.id for a in related_artifacts}
      events = self.context.store.get_events_by_artifact_ids(artifact_ids)
      events_by_artifact_id = {
          e.artifact_id: e
          for e in events
          if event_lib.is_valid_output_event(e)
      }

    # Build the result dictionary, with a separate key for each ArtifactType.
    for artifact in related_artifacts:
      # MLMD does not support filter querying by event.paths, so we manually
      # check for matching output key.
      # TODO(b/302394845): Once MLMD supports filtering by the last event, then
      # add this check inside the filter_query or event_filter.
      if self.output_keys:
        output_event = events_by_artifact_id.get(artifact.id)
        # An artifact without a valid output event cannot match any output
        # key, so it is filtered out rather than raising a KeyError.
        if output_event is None or not any(
            event_lib.contains_key(output_event, k) for k in self.output_keys
        ):
          continue

      deserialized_artifact = artifact_utils.deserialize_artifact(
          artifact_type_by_artifact_id[artifact.id], artifact
      )
      result[artifact.type].append(deserialized_artifact)

    return ops_utils.sort_artifact_dict(result)