in tfx/orchestration/experimental/core/pipeline_ops.py [0:0]
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates an active pipeline."""
  pipeline = pipeline_state.pipeline
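  # Under the pipeline state lock, ensure the MLMD execution is marked RUNNING
  # and enforce the run deadline for SYNC pipelines.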
  with pipeline_state:
    assert pipeline_state.is_active()
    if pipeline_state.get_pipeline_execution_state() != (
        metadata_store_pb2.Execution.RUNNING):
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.RUNNING)
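    # Per-run orchestration options (e.g. deadline_secs, fail_fast) stored
    # with the pipeline state.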
    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs
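    # A SYNC pipeline with a positive deadline is aborted once it has been
    # running longer than deadline_secs since creation.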
    if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and
        deadline_secs > 0 and
        time.time() - pipeline_state.pipeline_creation_time_secs_since_epoch()
        > deadline_secs):
      logging.error(
          'Aborting pipeline due to exceeding deadline (%s secs); '
          'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.DEADLINE_EXCEEDED,
              message=('Pipeline aborted due to exceeding deadline '
                       f'({deadline_secs} secs)')))
      return

  def _filter_by_state(node_infos: List[_NodeInfo],
                       state_str: str) -> List[_NodeInfo]:
    return [n for n in node_infos if n.state.state == state_str]
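
  # Snapshot the state of every node and collect those currently being
  # stopped.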
  node_infos = _get_node_infos(pipeline_state)
  stopping_node_infos = _filter_by_state(node_infos, pstate.NodeState.STOPPING)

  # Tracks nodes stopped in the current iteration.
  stopped_node_infos: List[_NodeInfo] = []

  # Create cancellation tasks for nodes in state STOPPING.
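  # A STOPPING node is handled according to its type:
  #   - pure service node: stop its service job; it is stopped once that
  #     succeeds.
  #   - node with an active execution: enqueue a cancellation task and let a
  #     later orchestration iteration observe the stopped state.
  #   - mixed service node with nothing left to cancel: stop its service job.
  #   - anything else: nothing to cancel, so it is stopped immediately.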
  for node_info in stopping_node_infos:
    if service_job_manager.is_pure_service_node(pipeline_state,
                                                node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node_info.node,
                                          task_queue):
      pass
    elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                   node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    else:
      stopped_node_infos.append(node_info)

  # Change the state of stopped nodes from STOPPING to STOPPED.
  if stopped_node_infos:
    with pipeline_state:
      for node_info in stopped_node_infos:
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline,
                                                       node_info.node)
        with pipeline_state.node_state_update_context(node_uid) as node_state:
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Initialize task generator for the pipeline.
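  # SYNC mode uses a generator that drives the pipeline DAG to completion
  # (optionally failing fast on the first node failure); ASYNC mode uses a
  # generator that keeps triggering nodes as new inputs become available.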
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle,
        task_queue.contains_task_id,
        service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))
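
  # Generate tasks reflecting the current state of the pipeline. These may be
  # UpdateNodeStateTasks, ExecNodeTasks, or (in SYNC mode) a single
  # FinalizePipelineTask.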
  tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Handle all the UpdateNodeStateTasks by updating node states.
    for task in tasks:
      if task_lib.is_update_node_state_task(task):
        task = typing.cast(task_lib.UpdateNodeStateTask, task)
        with pipeline_state.node_state_update_context(
            task.node_uid) as node_state:
          node_state.update(task.state, task.status)
    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

    # If there are still nodes in state STARTING, change them to STARTED.
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                                     node)
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

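    # Dispatch the remaining tasks: ExecNodeTasks are handed to the task queue
    # for execution; a FinalizePipelineTask (SYNC mode only) is exclusive and
    # stops the pipeline with the generator's final status.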
    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      else:
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
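

# A minimal sketch of how this function is typically driven (hypothetical loop
# for illustration; in this module, the public orchestrate() entrypoint
# performs the equivalent iteration over all active pipelines):
#
#   while not shutdown_requested:  # hypothetical flag
#     # `_active_pipeline_states` is a hypothetical helper returning a
#     # pstate.PipelineState for each active pipeline in MLMD.
#     for pipeline_state in _active_pipeline_states(mlmd_handle):
#       _orchestrate_active_pipeline(
#           mlmd_handle, task_queue, service_job_manager, pipeline_state)
#     time.sleep(polling_interval_secs)  # hypothetical polling interval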