# Extracted method: __call__
# From: tfx/orchestration/experimental/core/sync_pipeline_task_gen.py


  def __call__(self) -> List[task_lib.Task]:
    """Generates the next batch of tasks for a sync pipeline run.

    Walks the pipeline's nodes in topologically sorted order, classifying each
    node as succeeded, failed, blocked (upstream not yet successful), or ready,
    and collecting the tasks produced for ready nodes.

    Returns:
      A list of tasks: node-state updates first, then either a single
      `FinalizePipelineTask` (on abort or full success) or the exec-node tasks
      to run next.
    """
    sorted_layers = _topsorted_layers(self._pipeline)
    terminal_ids = _terminal_node_ids(sorted_layers)
    exec_tasks = []
    state_update_tasks = []
    succeeded_ids = set()
    # Maps node id -> failure status for nodes known to have failed.
    failures: Dict[str, status_lib.Status] = {}
    finalize_task = None

    # Flatten the layers; visiting nodes in this order is equivalent to the
    # nested per-layer walk, and a single `break` exits the whole traversal
    # once a fail-fast abort has been produced.
    flattened_nodes = [node for layer in sorted_layers for node in layer]
    for node in flattened_nodes:
      node_id = node.node_info.id
      node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
      state = self._node_states_dict[node_uid]
      if state.is_success():
        succeeded_ids.add(node_id)
        continue
      if state.is_failure():
        failures[node_id] = state.status
        continue
      # Skip nodes whose upstream dependencies haven't all succeeded yet.
      if not self._upstream_nodes_successful(node, succeeded_ids):
        continue
      for task in self._generate_tasks_for_node(node):
        if task_lib.is_update_node_state_task(task):
          task = typing.cast(task_lib.UpdateNodeStateTask, task)
          if pstate.is_node_state_success(task.state):
            succeeded_ids.add(node_id)
          elif pstate.is_node_state_failure(task.state):
            failures[node_id] = task.status
            # In fail-fast mode, a single node failure aborts the pipeline.
            if self._fail_fast:
              finalize_task = self._abort_task(task.status.message)
          state_update_tasks.append(task)
        elif task_lib.is_exec_node_task(task):
          exec_tasks.append(task)
      if finalize_task:
        break

    if not self._fail_fast and failures:
      # Invariant: finalize_task is only ever set in the fail-fast branch.
      assert not finalize_task
      node_by_id = _node_by_id(self._pipeline)
      # Nodes with a failed ancestor can never run.
      unrunnable_ids = set()
      for failed_id in failures:
        unrunnable_ids |= _descendants(node_by_id, failed_id)
      # A node is still runnable if it hasn't succeeded, hasn't failed, and
      # has no failed ancestor.
      runnable_ids = node_by_id.keys() - (
          unrunnable_ids | succeeded_ids | failures.keys())
      # With failures present and nothing left runnable, abort the pipeline.
      if not runnable_ids:
        finalize_task = self._abort_task(
            f'Cannot make progress due to node failures: {failures}')

    result = state_update_tasks
    if finalize_task:
      result.append(finalize_task)
    elif terminal_ids <= succeeded_ids:
      # All terminal nodes succeeded: the pipeline is complete.
      result.append(
          task_lib.FinalizePipelineTask(
              pipeline_uid=self._pipeline_uid,
              status=status_lib.Status(code=status_lib.Code.OK)))
    else:
      result.extend(exec_tasks)
    return result