tfx/dsl/input_resolution/ops/graph_traversal_op.py
def apply(self, input_list: Sequence[types.Artifact]):
  """Returns a dict with the upstream (or downstream) and root artifacts.

  Args:
    input_list: A list with exactly one Artifact to use as the root.

  Returns:
    A dictionary with the upstream (or downstream) artifacts, and the root
    artifact.

    For example, consider the lineage: Examples -> Model -> ModelBlessing.
    Calling GraphTraversal with [ModelBlessing], traverse_upstream=True, and
    artifact_type_names=["Examples"] will return:

    {
        "root_artifact": [ModelBlessing],
        "examples": [Examples],
    }

    Note the key "root_artifact" is set with the original artifact inside
    input_list. This makes input synchronization easier in an ASYNC pipeline.
  """
  if not input_list:
    return {}
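
  # At least one artifact type name is needed to know which neighbors to
  # return.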
  if not self.artifact_type_names:
    raise ValueError(
        'At least one artifact type name must be provided, but '
        'artifact_type_names was empty.'
    )

  # TODO(b/299985043): Support batch traversal.
  if len(input_list) != 1:
    raise ValueError(
        'GraphTraversal ResolverOp does not support batch traversal.'
    )
  root_artifact = input_list[0]

  # Query MLMD to get the upstream (or downstream) artifacts.
  artifact_states_filter_query = (
      ops_utils.get_valid_artifact_states_filter_query(_VALID_ARTIFACT_STATES)
  )
  filter_query = (
      f'type IN {q.to_sql_string(self.artifact_type_names)} AND '
      f'{artifact_states_filter_query}'
  )
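  # The resulting filter_query looks roughly like (illustrative; the state
  # clause comes from ops_utils.get_valid_artifact_states_filter_query):
  #   type IN ('Examples') AND <valid-artifact-states clause>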

  if self.node_ids:
    for context in self.context.store.get_contexts_by_artifact(
        root_artifact.id
    ):
      if context.type == constants.PIPELINE_CONTEXT_TYPE_NAME:
        pipeline_name = context.name
        break
    else:
      raise ValueError('No pipeline context was found.')

    # We match against the node Context's name, which has the format
    # <pipeline-name>.<node-id>.
    node_context_names = [
        compiler_utils.node_context_name(pipeline_name, ni)
        for ni in self.node_ids
    ]
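    # e.g. node_ids=['my-trainer'] in pipeline 'my-pipeline' matches the node
    # context 'my-pipeline.my-trainer' (hypothetical names).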
    query = (
        f'contexts_a.name IN {q.to_sql_string(node_context_names)} '
        'AND contexts_a.type = '
        f'{q.to_sql_string(constants.NODE_CONTEXT_TYPE_NAME)}'
    )
    filter_query += ' AND ' + query
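
  # Traverse the MLMD lineage graph in the requested direction.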
  mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
  mlmd_resolver_fn = (
      mlmd_resolver.get_upstream_artifacts_by_artifact_ids
      if self.traverse_upstream
      else mlmd_resolver.get_downstream_artifacts_by_artifact_ids
  )
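  # Both resolver functions return a dict mapping each queried artifact id to
  # a list of (Artifact, ArtifactType) tuples.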
  related_artifact_and_type = mlmd_resolver_fn(
      [root_artifact.id],
      max_num_hops=ops_utils.GRAPH_TRAVERSAL_OP_MAX_NUM_HOPS,
      filter_query=filter_query,
  )

  artifact_type_by_id = {}
  related_artifacts = {}
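  # Unzip each (Artifact, ArtifactType) list into parallel collections, keyed
  # by the queried artifact id.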
  for artifact_id, artifacts_and_types in related_artifact_and_type.items():
    related_artifacts[artifact_id], artifact_types = zip(*artifacts_and_types)
    artifact_type_by_id.update({t.id: t for t in artifact_types})

  # Build the result dict to return. We include the root_artifact to help with
  # input synchronization in ASYNC mode. Note that Python dicts preserve key
  # insertion order, so when a user gets the unrolled dict values, they will
  # first get the root artifact, followed by ancestor/descendant artifacts in
  # the same order as self.artifact_type_names.
  result = {ops_utils.ROOT_ARTIFACT_KEY: [root_artifact]}
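  # Pre-populate an empty list for every requested type so each key is always
  # present in the result, even when no matching artifacts are found.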
  for artifact_type in self.artifact_type_names:
    result[artifact_type] = []

  if not related_artifacts.get(root_artifact.id):
    logging.info(
        'No neighboring artifacts were found for root artifact %s with '
        'artifact_type_names %s, node_ids %s, and output_keys %s.',
        root_artifact,
        self.artifact_type_names,
        self.node_ids,
        self.output_keys,
    )
    return result
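
  # Narrow to the neighbors of the single root artifact.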
  related_artifacts = related_artifacts[root_artifact.id]

  # Get the ArtifactType for the related artifacts.
  artifact_type_by_artifact_id = {}
  for artifact in related_artifacts:
    artifact_type_by_artifact_id[artifact.id] = artifact_type_by_id[
        artifact.type_id
    ]

  # Build the result dictionary, with a separate key for each ArtifactType.
  artifact_ids = set(a.id for a in related_artifacts)
  events = self.context.store.get_events_by_artifact_ids(artifact_ids)
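  # Index each artifact's valid OUTPUT event by artifact id for the
  # output_keys check below.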
  events_by_artifact_id = {
      e.artifact_id: e for e in events if event_lib.is_valid_output_event(e)
  }
  for artifact in related_artifacts:
    # MLMD does not support filter querying by event.paths, so we manually
    # check for a matching output key.
    # TODO(b/302394845): Once MLMD supports filtering by the last event, then
    # add this check inside the filter_query or event_filter.
    if self.output_keys and not any(
        event_lib.contains_key(events_by_artifact_id[artifact.id], k)
        for k in self.output_keys
    ):
      continue
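    # Convert the MLMD Artifact proto into its corresponding TFX Artifact.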
    deserialized_artifact = artifact_utils.deserialize_artifact(
        artifact_type_by_artifact_id[artifact.id], artifact
    )
    result[artifact.type].append(deserialized_artifact)
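
  # sort_artifact_dict keeps each list in the result deterministically
  # ordered (see ops_utils).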
  return ops_utils.sort_artifact_dict(result)