in tfx/orchestration/experimental/core/sync_pipeline_task_gen.py [0:0]
def __call__(self) -> List[task_lib.Task]:
  """Generates the tasks for the next step of a sync pipeline run.

  Nodes are visited layer by layer, in the order returned by
  `_topsorted_layers`. For each node:

  * if its recorded node state is already success or failure, the node is
    only recorded (as successful or failed) and skipped;
  * if not all of its upstream nodes are successful, it is skipped;
  * otherwise `_generate_tasks_for_node` is called. The returned
    `UpdateNodeStateTask`s and exec node tasks are collected. A node-state
    update reporting success or failure is also recorded for the rest of
    this call.

  In fail-fast mode, the first node-state update reporting a failure
  produces an abort task (built by `self._abort_task`) and stops the walk.

  Returns:
    The collected `UpdateNodeStateTask`s, followed by exactly one of:
    * the abort task, if fail-fast triggered, or if there are failed nodes
      and no node is left that could still run;
    * a `FinalizePipelineTask` with an OK status, if every terminal node is
      successful (any exec node tasks generated in this call are dropped);
    * otherwise, the collected exec node tasks.
  """
  layers = _topsorted_layers(self._pipeline)
  terminal_node_ids = _terminal_node_ids(layers)
  exec_node_tasks = []
  update_node_state_tasks = []
  successful_node_ids = set()
  failed_nodes_dict: Dict[str, status_lib.Status] = {}
  finalize_pipeline_task = None
  for layer_nodes in layers:
    for node in layer_nodes:
      node_id = node.node_info.id
      node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
      node_state = self._node_states_dict[node_uid]
      if node_state.is_success():
        successful_node_ids.add(node_id)
        continue
      if node_state.is_failure():
        # A failure already recorded in node state does not produce an abort
        # here, even in fail-fast mode; it is handled by the progress check
        # after the walk.
        failed_nodes_dict[node_id] = node_state.status
        continue
      if not self._upstream_nodes_successful(node, successful_node_ids):
        continue
      for task in self._generate_tasks_for_node(node):
        if task_lib.is_update_node_state_task(task):
          task = typing.cast(task_lib.UpdateNodeStateTask, task)
          # Record the new state so downstream nodes visited later in this
          # same call see it.
          if pstate.is_node_state_success(task.state):
            successful_node_ids.add(node_id)
          elif pstate.is_node_state_failure(task.state):
            failed_nodes_dict[node_id] = task.status
            if self._fail_fast:
              finalize_pipeline_task = self._abort_task(task.status.message)
          update_node_state_tasks.append(task)
        elif task_lib.is_exec_node_task(task):
          exec_node_tasks.append(task)
      if finalize_pipeline_task:
        break
    if finalize_pipeline_task:
      break

  # Progress check: run it whenever there are failures and no abort was
  # produced during the walk, in both fail-fast and non-fail-fast modes.
  # In fail-fast mode, failures observed via recorded node state above never
  # set `finalize_pipeline_task`. Without this check, the pipeline could stay
  # running with nothing left to execute.
  if failed_nodes_dict and not finalize_pipeline_task:
    node_by_id = _node_by_id(self._pipeline)
    # Collect nodes that cannot be run because they have a failed ancestor.
    unrunnable_node_ids = set()
    for failed_node_id in failed_nodes_dict:
      unrunnable_node_ids |= _descendants(node_by_id, failed_node_id)
    # Nodes that are still runnable have neither succeeded nor failed, and
    # don't have a failed ancestor.
    runnable_node_ids = node_by_id.keys() - (
        unrunnable_node_ids | successful_node_ids | failed_nodes_dict.keys())
    # If there are no runnable nodes, we can abort the pipeline.
    if not runnable_node_ids:
      finalize_pipeline_task = self._abort_task(
          f'Cannot make progress due to node failures: {failed_nodes_dict}')

  result = update_node_state_tasks
  if finalize_pipeline_task:
    result.append(finalize_pipeline_task)
  elif terminal_node_ids <= successful_node_ids:
    # If all terminal nodes are successful, the pipeline can be finalized.
    result.append(
        task_lib.FinalizePipelineTask(
            pipeline_uid=self._pipeline_uid,
            status=status_lib.Status(code=status_lib.Code.OK)))
  else:
    result.extend(exec_node_tasks)
  return result