def _orchestrate_active_pipeline()

in tfx/orchestration/experimental/core/pipeline_ops.py [0:0]


def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Runs one orchestration pass over an active pipeline.

  The pass proceeds in this order:
    1. Sets the pipeline execution state to RUNNING if it is not already, and
       initiates a stop with DEADLINE_EXCEEDED (then returns) when a SYNC
       pipeline has outlived a positive `deadline_secs` orchestration option.
    2. For every node in state STOPPING, stops its services and/or tries to
       enqueue a cancellation task; nodes determined to be stopped are moved
       to state STOPPED.
    3. Generates tasks using the SYNC or ASYNC task generator.
    4. Applies node state update tasks, moves nodes still in STARTING to
       STARTED, enqueues exec node tasks and, for a finalize pipeline task,
       initiates a pipeline stop with that task's status.

  Args:
    mlmd_handle: Metadata handle forwarded to the cancellation helper and to
      the task generators.
    task_queue: Queue on which exec node tasks are enqueued; its
      `contains_task_id` method is handed to the task generators.
    service_job_manager: Used to classify service nodes and to stop their
      services; also forwarded to the task generators.
    pipeline_state: State of the pipeline to orchestrate; asserted to be
      active.

  Raises:
    status_lib.StatusNotOkError: With code FAILED_PRECONDITION when the
      pipeline execution mode is neither SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  execution_mode = pipeline.execution_mode

  with pipeline_state:
    assert pipeline_state.is_active()
    running = metadata_store_pb2.Execution.RUNNING
    if pipeline_state.get_pipeline_execution_state() != running:
      pipeline_state.set_pipeline_execution_state(running)

    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs

    # The deadline applies to SYNC pipelines only, and only when it is
    # positive; the clock is read solely in that case.
    if execution_mode == pipeline_pb2.Pipeline.SYNC and deadline_secs > 0:
      elapsed_secs = (
          time.time() -
          pipeline_state.pipeline_creation_time_secs_since_epoch())
      if elapsed_secs > deadline_secs:
        logging.error(
            'Aborting pipeline due to exceeding deadline (%s secs); '
            'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
        abort_message = ('Pipeline aborted due to exceeding deadline '
                         f'({deadline_secs} secs)')
        pipeline_state.initiate_stop(
            status_lib.Status(
                code=status_lib.Code.DEADLINE_EXCEEDED,
                message=abort_message))
        return

  # Nodes currently in state STOPPING.
  nodes_being_stopped = [
      info for info in _get_node_infos(pipeline_state)
      if info.state.state == pstate.NodeState.STOPPING
  ]

  # Nodes found to be stopped during this pass.
  newly_stopped: List[_NodeInfo] = []

  for stopping_info in nodes_being_stopped:
    node_id = stopping_info.node.node_info.id
    if service_job_manager.is_pure_service_node(pipeline_state, node_id):
      # Pure service node: stopped once its services report stopped.
      is_stopped = service_job_manager.stop_node_services(
          pipeline_state, node_id)
    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                          stopping_info.node, task_queue):
      # A cancellation task was enqueued; the node is not yet stopped.
      is_stopped = False
    elif service_job_manager.is_mixed_service_node(pipeline_state, node_id):
      is_stopped = service_job_manager.stop_node_services(
          pipeline_state, node_id)
    else:
      # No services to stop and no cancellation task was enqueued.
      is_stopped = True
    if is_stopped:
      newly_stopped.append(stopping_info)

  # Move the stopped nodes from STOPPING to STOPPED, keeping their status.
  if newly_stopped:
    with pipeline_state:
      for stopped_info in newly_stopped:
        stopped_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, stopped_info.node)
        with pipeline_state.node_state_update_context(
            stopped_uid) as node_state:
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Pick the task generator that matches the pipeline execution mode.
  if execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle,
        task_queue.contains_task_id,
        service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  generated_tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Apply node state updates right away; collect every other task.
    remaining_tasks = []
    for generated in generated_tasks:
      if not task_lib.is_update_node_state_task(generated):
        remaining_tasks.append(generated)
        continue
      update_task = typing.cast(task_lib.UpdateNodeStateTask, generated)
      with pipeline_state.node_state_update_context(
          update_task.node_uid) as node_state:
        node_state.update(update_task.state, update_task.status)

    # Any node still in state STARTING is moved to STARTED.
    for pipeline_node in pstate.get_all_pipeline_nodes(
        pipeline_state.pipeline):
      starting_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline_state.pipeline, pipeline_node)
      with pipeline_state.node_state_update_context(
          starting_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

    for pending in remaining_tasks:
      if task_lib.is_exec_node_task(pending):
        task_queue.enqueue(typing.cast(task_lib.ExecNodeTask, pending))
        continue

      # Anything else must be the sole finalize task of a SYNC pipeline.
      assert task_lib.is_finalize_pipeline_task(pending)
      assert execution_mode == pipeline_pb2.Pipeline.SYNC
      assert len(remaining_tasks) == 1
      finalize_task = typing.cast(task_lib.FinalizePipelineTask, pending)
      if finalize_task.status.code == status_lib.Code.OK:
        outcome_format = 'Pipeline run successful; pipeline uid: %s'
      else:
        outcome_format = 'Pipeline run failed; pipeline uid: %s'
      logging.info(outcome_format, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(finalize_task.status)