pai/pipeline/run.py
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import

import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional

from ..api.base import PaginatedResult
from ..common.logging import get_logger
from ..exception import PAIException
from ..session import Session, get_default_session
from .artifact import ArchivedArtifact

logger = get_logger(__name__)


# TODO: review the status names of the PipelineRun.
class PipelineRunStatus(object):
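    """Statuses of a pipeline run and of the nodes within a run."""
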
Initialized = "Initialized"
ReadyToSchedule = "ReadyToSchedule"
Starting = "Starting"
Running = "Running"
WorkflowServiceStarting = "WorkflowServiceStarting"
Suspended = "Suspended"
Succeeded = "Succeeded"
Terminated = "Terminated"
Unknown = "Unknown"
Skipped = "Skipped"
Failed = "Failed"
@classmethod
    def completed_status(cls):
        return [
            cls.Suspended,
            cls.Succeeded,
            cls.Terminated,
            cls.Skipped,
            cls.Failed,
        ]
    @classmethod
    def is_running(cls, status):
        return status in (
            cls.Starting,
            cls.Running,
            cls.WorkflowServiceStarting,
            cls.ReadyToSchedule,
        )


class PipelineRun(object):
    """Class representing a pipeline run resource."""

def __init__(
self,
run_id,
name=None,
workspace_id=None,
status=None,
node_id=None,
duration=None,
started_at=None,
finished_at=None,
source=None,
user_id=None,
parent_user_id=None,
session=None,
):
self.run_id = run_id
self.name = name
self.workspace_id = workspace_id
self.status = status
self.node_id = node_id
self.duration = duration
self.started_at = started_at
self.finished_at = finished_at
self.source = source
self.user_id = user_id
self.parent_user_id = parent_user_id
self.session = session or get_default_session()
@classmethod
def get(cls, run_id, session=None) -> "PipelineRun":
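        """Get a PipelineRun by run ID.

        Args:
            run_id (str): ID of the pipeline run.
            session (:class:`pai.session.Session`, optional): A PAI session
                instance; the default session is used if not provided.

        Returns:
            PipelineRun: The pipeline run instance.
        """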
session = session or get_default_session()
return cls.deserialize(session.pipeline_run_api.get(run_id=run_id))
@classmethod
def run(
cls,
name,
arguments,
env=None,
pipeline_id: Optional[str] = None,
manifest: Optional[str] = None,
no_confirm_required: bool = True,
session: Optional[Session] = None,
):
"""Submit a pipeline run with pipeline operator and run arguments.
If pipeline_id is supplied, remote pipeline manifest is used as workflow template.
Args:
name (str): PipelineRun instance name of the submitted job.
arguments (dict): Run arguments required by pipeline manifest.
env (list): Environment arguments of run.
pipeline_id (str): Pipeline
manifest (str or dict): Pipeline manifest of the run workflow.
no_confirm_required (bool): Run workflow start immediately if true
else start_run service call if required to start the workflow.
session (:class:`pai.session.Session`): A PAI session instance used for
communicating with PAI service.
Returns:
str:run id if run workflow init success.
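
        Example:
            A minimal sketch; the pipeline id and argument values below are
            hypothetical and should be replaced with values from your own
            workspace::

                run_id = PipelineRun.run(
                    name="example-run",
                    arguments={"input_table": "<your-table>"},
                    pipeline_id="<pipeline-id>",
                )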
"""
session = session or get_default_session()
run_id = session.pipeline_run_api.create(
name=name,
arguments=arguments,
env=env,
manifest=manifest,
pipeline_id=pipeline_id,
no_confirm_required=no_confirm_required,
)
        run = PipelineRun.get(run_id, session=session)
        logger.info(
            "Pipeline run created successfully (run_id: {run_id}). Please visit"
            " the link below to view the run details.".format(run_id=run_id)
        )
        logger.info(run.console_uri)
        return run_id
@classmethod
def list(
cls,
name=None,
run_id=None,
pipeline_id=None,
status=None,
sort_by=None,
order=None,
page_size=20,
page_number=1,
session=None,
**kwargs,
):
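        """List pipeline runs matching the given filters.

        Args:
            name (str, optional): Name filter for the pipeline runs.
            run_id (str, optional): ID filter for the pipeline runs.
            pipeline_id (str, optional): Filter by the pipeline used to create
                the runs.
            status (str, optional): Filter by run status.
            sort_by (str, optional): Field used to sort the results.
            order (str, optional): Sort order of the results.
            page_size (int): Number of runs returned per page.
            page_number (int): Page number to retrieve, starting from 1.
            session (:class:`pai.session.Session`, optional): A PAI session
                instance; the default session is used if not provided.

        Returns:
            list: A list of :class:`PipelineRun` instances.
        """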
session = session or get_default_session()
result = session.pipeline_run_api.list(
name=name,
run_id=run_id,
pipeline_id=pipeline_id,
status=status,
sort_by=sort_by,
order=order,
workspace_id=None,
page_size=page_size,
page_number=page_number,
**kwargs,
)
return [cls.deserialize(run) for run in result.items]
@classmethod
def deserialize(cls, d):
return cls(
run_id=d["RunId"],
node_id=d["NodeId"],
name=d["Name"],
workspace_id=d["WorkspaceId"],
user_id=d.get("UserId"),
parent_user_id=d.get("ParentUserId"),
source=d.get("Source"),
started_at=d.get("StartedAt"),
status=d.get("Status"),
)
def __repr__(self):
return "PipelineRun:%s" % self.run_id
def travel_node_status_info(self, node_id, max_depth=10):
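        """Traverse the node tree of the run and collect per-node status info.

        The traversal starts from the given node and recursively descends into
        sub-pipelines (DAG nodes), up to ``max_depth`` levels.

        Args:
            node_id: ID of the node to start the traversal from.
            max_depth (int): Maximum recursion depth.

        Returns:
            dict: A mapping from the fully qualified node name (a dot-separated
                path from the run root) to the node status info.
        """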
node_status_info = dict()
def pipelines_travel(curr_node_id, parent=None, cur_depth=1):
if cur_depth > max_depth:
return
run_node_detail_info = self.session.pipeline_run_api.get_node(
self.run_id,
curr_node_id,
depth=2,
)
if (
not run_node_detail_info
or "StartedAt" not in run_node_detail_info["StatusInfo"]
):
return
            if parent is None:
                curr_root_name = self.name
            else:
                curr_root_name = "{0}.{1}".format(
                    parent, run_node_detail_info["Metadata"]["Name"]
                )
node_status_info[curr_root_name] = self._pipeline_node_info(
run_node_detail_info
)
pipelines = run_node_detail_info["Spec"].get("Pipelines", [])
if not pipelines:
return
for sub_pipeline in pipelines:
node_name = "{0}.{1}".format(
curr_root_name, sub_pipeline["Metadata"]["Name"]
)
node_status_info[node_name] = self._pipeline_node_info(sub_pipeline)
next_node_id = sub_pipeline["Metadata"]["NodeId"]
if sub_pipeline["Metadata"]["NodeType"] == "Dag" and next_node_id:
pipelines_travel(next_node_id, curr_root_name, cur_depth + 1)
pipelines_travel(curr_node_id=node_id)
return node_status_info
@staticmethod
def _pipeline_node_info(pipeline_info):
return {
"name": pipeline_info["Metadata"]["Name"],
"nodeId": pipeline_info["Metadata"]["NodeId"],
"status": pipeline_info["StatusInfo"]["Status"],
"startedAt": pipeline_info["StatusInfo"]["StartedAt"],
"finishedAt": pipeline_info["StatusInfo"].get("FinishedAt", None),
}
@property
def console_uri(self):
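        """URL of the run detail page in the PAI console."""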
if not self.session.is_inner:
return "{console_host}?regionId={region_id}#/studio/task/detail/{run_id}".format(
console_host=self.session.console_uri,
region_id=self.session.region_id,
run_id=self.run_id,
)
return "{console_host}/#/studio/task/detail/{run_id}".format(
console_host=self.session.console_uri, run_id=self.run_id
)
def get_run_info(self):
return self.session.pipeline_run_api.get(self.run_id)
def get_run_node_detail(self, node_id, depth=2):
return self.session.pipeline_run_api.get_node(
self.run_id, node_id=node_id, depth=depth
)
def get_outputs(self, name=None, node_id=None, depth=1, type=None):
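        """Get the output artifacts of a node in the run.

        If ``node_id`` is not given, the root node of the run is used.

        Args:
            name (str, optional): Name filter for the outputs.
            node_id (str, optional): ID of the node whose outputs are listed.
            depth (int): Depth of sub-nodes to include.
            type (str, optional): Artifact type filter.

        Returns:
            list: A list of :class:`ArchivedArtifact` instances.
        """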
if not node_id:
run_info = self.get_run_info()
node_id = run_info["NodeId"]
        if not node_id:
            return []
result = self.session.pipeline_run_api.list_node_outputs(
name=name,
node_id=node_id,
run_id=self.run_id,
depth=depth,
type=type,
)
return [ArchivedArtifact.deserialize(output) for output in result.items]
def get_status(self):
return self.get_run_info()["Status"]
def start(self):
self.session.pipeline_run_api.start(self.run_id)
def terminate(self):
self.session.pipeline_run_api.terminate(self.run_id)
    def _wait_for_init(self, retry_interval=1):
        """Wait until a "NodeId" is allocated to the pipeline run."""
        run_info = self.get_run_info()
        while (
            PipelineRunStatus.is_running(run_info["Status"]) and not run_info["NodeId"]
        ):
            time.sleep(retry_interval)
            run_info = self.get_run_info()
        if run_info.get("NodeId", None):
            return run_info["NodeId"]
        else:
            raise ValueError("Failed to acquire the root node_id of the pipeline run.")
def wait_for_completion(self, show_outputs=True):
"""Wait until the pipeline run stop."""
run_info = self.get_run_info()
node_id = run_info["NodeId"]
if not node_id:
raise ValueError("Expect NodeId in GetRun response")
run_status = run_info["Status"]
        if run_status == PipelineRunStatus.Initialized:
            raise ValueError(
                'Pipeline run instance is in status "Initialized", please start'
                " the run instance."
            )
        elif run_status in (PipelineRunStatus.Terminated, PipelineRunStatus.Suspended):
            raise ValueError(
                "Pipeline run instance is stopped (status: %s), please resume/retry"
                " the run." % run_status
            )
        elif run_status == PipelineRunStatus.Failed:
            raise ValueError("Pipeline run failed.")
        elif run_status in (PipelineRunStatus.Skipped, PipelineRunStatus.Unknown):
            raise ValueError(
                "Pipeline run is in an unexpected status (%s: %s)"
                % (self.run_id, run_status)
            )
        elif run_status == PipelineRunStatus.Succeeded:
            return self
        # Wait for the run workflow to initialize.
        print("Waiting for the run workflow to initialize.")
if show_outputs:
run_logger = _RunLogger(
run_instance=self, node_id=node_id, session=self.session
)
else:
run_logger = _MockRunLogger(run_instance=self, node_id=node_id)
try:
prev_status_infos = {}
root_node_status = run_status
log_runners = []
while PipelineRunStatus.is_running(root_node_status):
curr_status_infos = self.travel_node_status_info(node_id)
for node_fullname, status_info in curr_status_infos.items():
if (
node_fullname not in prev_status_infos
and status_info["status"] != PipelineRunStatus.Skipped
):
log_runner = run_logger.submit(
node_id=status_info["nodeId"], node_name=node_fullname
)
if log_runner:
log_runners.append(log_runner)
prev_status_infos = curr_status_infos
root_node_status = (
curr_status_infos[self.name]["status"]
if self.name in curr_status_infos
else root_node_status
)
if root_node_status == PipelineRunStatus.Failed:
raise PAIException(
"PipelineRun failed: run_id={}, run_status_info={}".format(
self.run_id, curr_status_infos
)
)
failed_nodes = {
name: status_info
for name, status_info in curr_status_infos.items()
if PipelineRunStatus.Failed == status_info["status"]
}
if failed_nodes:
raise PAIException(
"PipelineRun failed: run_id={}, failed_nodes={}".format(
self.run_id, failed_nodes
)
)
time.sleep(2)
except (KeyboardInterrupt, PAIException) as e:
run_logger.stop_tail()
raise e
for log_runner in log_runners:
_ = log_runner.result()
return self
def _wait_with_progress(self):
pass
def _wait_with_logger(self, node_id):
pass


def make_log_iterator(method: Callable, **kwargs):
    """Make an iterator over a paginated resource list API.

    Args:
        method: The resource list API to call.
        **kwargs: Arguments for the method.
Returns:
A resource iterator.
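
    Example:
        A sketch of iterating over node logs; the run and node ids below are
        hypothetical::

            for line in make_log_iterator(
                session.pipeline_run_api.list_node_logs,
                run_id="<run-id>",
                node_id="<node-id>",
            ):
                print(line)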
"""
page_offset = kwargs.get("page_offset", 0)
page_size = kwargs.get("page_size", 20)
while True:
kwargs.update(page_offset=page_offset, page_size=page_size)
result: PaginatedResult = method(**kwargs)
for item in result.items:
yield item
        # A partial (or empty) page indicates the last page has been reached.
        if len(result.items) < page_size:
            return
page_offset += page_size


class _RunLogger(object):
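    """Tail and print the logs of pipeline run nodes in background threads."""
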
executor = ThreadPoolExecutor(5)
def __init__(self, run_instance, node_id, session):
super(_RunLogger, self).__init__()
self.run_instance = run_instance
self.node_id = node_id
self.running_nodes = set()
self.session = session
self._tail = True
def tail(
self,
node_id,
node_name,
page_size=100,
page_offset=0,
):
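        """Continuously fetch and print the logs of a node until the run stops.

        Args:
            node_id: ID of the node whose logs are tailed.
            node_name: Display name used as the prefix of each printed line.
            page_size (int): Number of log lines fetched per request.
            page_offset (int): Line offset to start fetching from.
        """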
if node_id in self.running_nodes:
return
self.running_nodes.add(node_id)
        while self._tail:
logs = make_log_iterator(
self.session.pipeline_run_api.list_node_logs,
run_id=self.run_instance.run_id,
node_id=node_id,
page_size=page_size,
page_offset=page_offset,
)
count = 0
for log in logs:
print("%s: %s" % (node_name, log))
page_offset += 1
count += 1
if count % page_size == 0:
time.sleep(0.5)
if count == 0:
status = self.run_instance.get_status()
if PipelineRunStatus.is_running(status):
time.sleep(2)
else:
break
def submit(
self,
node_id,
node_name,
page_size=100,
page_offset=0,
):
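        """Schedule log tailing for a node on the thread pool.

        Returns:
            concurrent.futures.Future: A future for the tailing task, or None
                if the node is already being tailed.
        """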
print("Add Node Logger: {}, {}".format(node_name, node_id))
if node_id in self.running_nodes:
return
return self.executor.submit(
self.tail,
node_id=node_id,
node_name=node_name,
page_size=page_size,
page_offset=page_offset,
)
def stop_tail(self):
self._tail = False


class _MockRunLogger(object):
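    """No-op stand-in for _RunLogger used when log output is disabled."""
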
def __init__(self, run_instance, node_id):
super(_MockRunLogger, self).__init__()
self.run_instance = run_instance
self.node_id = node_id
def tail(self, **kwargs):
pass
def submit(self, *args, **kwargs):
pass
def stop_tail(self):
pass