in tfx/dsl/input_resolution/ops/siblings_op.py [0:0]
def apply(self, input_list: Sequence[types.Artifact]):
  """Resolves sibling artifacts of a single root artifact.

  Queries MLMD for the execution(s) that consumed the root artifact and
  gathers every artifact those executions produced, grouped by output key.
  The root artifact itself is returned under `ops_utils.ROOT_ARTIFACT_KEY`.

  Args:
    input_list: A sequence holding exactly one root artifact, or empty.

  Returns:
    A dict mapping output keys to lists of sibling artifacts. The root
    artifact is inserted first, so callers iterating the dict values see it
    before any siblings.

  Raises:
    ValueError: If more than one artifact is supplied; batch traversal is
      not supported yet (b/299985043).
  """
  if not input_list:
    return {}
  # TODO(b/299985043): Support batch traversal.
  if len(input_list) != 1:
    raise ValueError('Siblings ResolverOp does not support batch queries.')
  root_artifact = input_list[0]

  state_filter = ops_utils.get_valid_artifact_states_filter_query(
      _VALID_ARTIFACT_STATES
  )
  starting_nodes = (
      metadata_store_pb2.LineageSubgraphQueryOptions.StartingNodes(
          filter_query=(
              f'id = {root_artifact.id} AND '
              f'{state_filter}'
          ),
      )
  )
  # NOTE: This query assumes that an artifact will never be the input of one
  # execution and the output of another (or the same) execution. This is
  # always the case in Tflex, because the orchestrator produces new output
  # artifacts for every execution.
  ending_nodes = metadata_store_pb2.LineageSubgraphQueryOptions.EndingNodes(
      filter_query=(
          f'events_0.artifact_id = {root_artifact.id} AND'
          ' events_0.type = INPUT'
      )
  )
  subgraph = self.context.store.get_lineage_subgraph(
      query_options=metadata_store_pb2.LineageSubgraphQueryOptions(
          starting_artifacts=starting_nodes,
          ending_executions=ending_nodes,
          max_num_hops=2,
          direction=(
              metadata_store_pb2.LineageSubgraphQueryOptions.BIDIRECTIONAL
          ),
      ),
      field_mask_paths=['artifacts', 'artifact_types', 'events'],
  )

  if not self.output_keys:
    # Discover all output keys from the subgraph's output events.
    discovered_keys = set()
    for event in subgraph.events:
      if not event_lib.is_valid_output_event(event):
        continue
      # Skip events on the root artifact itself, so it only ever appears
      # under ROOT_ARTIFACT_KEY in the returned dictionary.
      if event.artifact_id == root_artifact.id:
        continue
      for key, _ in event_lib._parse_path(event):  # pylint: disable=protected-access
        discovered_keys.add(key)
    self.output_keys = list(discovered_keys)

  # Seed the result dict. The root artifact goes in first on purpose: Python
  # dicts preserve insertion order, so unrolled values yield the root
  # artifact before siblings, which helps input synchronization in ASYNC
  # mode. Siblings then follow in the order of self.output_keys.
  result = {ops_utils.ROOT_ARTIFACT_KEY: [root_artifact]}
  result.update((key, []) for key in self.output_keys)

  # Index the subgraph so events can be resolved to typed artifacts.
  artifacts_by_id = {a.id: a for a in subgraph.artifacts}
  types_by_id = {t.id: t for t in subgraph.artifact_types}
  for event in subgraph.events:
    if not event_lib.is_valid_output_event(event):
      continue
    for key in self.output_keys:
      if not event_lib.contains_key(event, key):
        continue
      raw_artifact = artifacts_by_id[event.artifact_id]
      result[key].append(
          artifact_utils.deserialize_artifact(
              types_by_id[raw_artifact.type_id], raw_artifact
          )
      )
  return ops_utils.sort_artifact_dict(result)