core/maxframe_client/session/task.py (259 lines of code) (raw):
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
import logging
import time
from typing import Any, Dict, List, Optional, Type, Union
import msgpack
from odps import ODPS
from odps import options as odps_options
from odps.errors import EmptyTaskInfoError, parse_instance_error
from odps.models import Instance, MaxFrameTask
from maxframe.config import options
from maxframe.core import TileableGraph
from maxframe.errors import NoTaskServerResponseError, SessionAlreadyClosedError
from maxframe.protocol import DagInfo, JsonSerializable, ResultInfo, SessionInfo
from maxframe.utils import deserialize_serializable, serialize_serializable, to_str
try:
from maxframe import __version__ as mf_version
except ImportError:
mf_version = None
from .consts import (
EMPTY_RESPONSE_RETRY_COUNT,
MAXFRAME_DEFAULT_PROTOCOL,
MAXFRAME_OUTPUT_JSON_FORMAT,
MAXFRAME_OUTPUT_MAXFRAME_FORMAT,
MAXFRAME_OUTPUT_MSGPACK_FORMAT,
MAXFRAME_TASK_CANCEL_DAG_METHOD,
MAXFRAME_TASK_CREATE_SESSION_METHOD,
MAXFRAME_TASK_DECREF_METHOD,
MAXFRAME_TASK_DELETE_SESSION_METHOD,
MAXFRAME_TASK_GET_DAG_INFO_METHOD,
MAXFRAME_TASK_GET_SESSION_METHOD,
MAXFRAME_TASK_SUBMIT_DAG_METHOD,
ODPS_SESSION_INSECURE_SCHEME,
ODPS_SESSION_SECURE_SCHEME,
)
from .odps import MaxFrameServiceCaller, MaxFrameSession
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class MaxFrameInstanceCaller(MaxFrameServiceCaller):
    """
    Service caller that drives a MaxFrame session through an ODPS instance.

    Requests are sent to the instance with ``put_task_info`` and responses are
    decoded according to the configured output format. When
    ``nested_instance_id`` is supplied, the caller attaches to that existing
    instance ("nested" mode) instead of creating a new one.
    """

    # ODPS instance backing the session; None until a session is created
    # in non-nested mode.
    _instance: Optional[Instance]

    def __init__(
        self,
        odps_entry: ODPS,
        task_name: Optional[str] = None,
        project: Optional[str] = None,
        priority: Optional[int] = None,
        running_cluster: Optional[str] = None,
        nested_instance_id: Optional[str] = None,
        major_version: Optional[str] = None,
        output_format: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        odps_entry: ODPS entry object used to create or look up instances
        task_name: name of the MaxFrame task inside the instance
        project: name of the project in which the instance is created
        priority: instance priority; falls back to the PyODPS option when None
        running_cluster: cluster passed on when the instance is created
        nested_instance_id: id of an existing instance to attach to
        major_version: major version passed to the MaxFrame task
        output_format: response serialization format, msgpack when not given
        kwargs: extra arguments, accepted and ignored
        """
        # PyODPS may expose the default priority either as a callable
        # resolver or as a plain option value.
        if callable(odps_options.get_priority):
            default_priority = odps_options.get_priority(odps_entry)
        else:
            default_priority = odps_options.priority
        priority = priority if priority is not None else default_priority

        self._odps_entry = odps_entry
        self._task_name = task_name
        self._project = project
        self._priority = priority
        self._running_cluster = running_cluster
        self._major_version = major_version
        self._output_format = output_format or MAXFRAME_OUTPUT_MSGPACK_FORMAT
        self._deleted = False

        if nested_instance_id is None:
            self._nested = False
            self._instance = None
        else:
            self._nested = True
            self._instance = odps_entry.get_instance(nested_instance_id)

    @property
    def instance(self):
        """The ODPS instance currently bound to this caller, or None."""
        return self._instance

    def _deserial_task_info_result(
        self, content: Union[bytes, str, dict], target_cls: Type[JsonSerializable]
    ):
        """
        Decode a task info response into an object.

        Parameters
        ----------
        content: raw response as JSON text / bytes, or an already parsed dict
        target_cls: class whose ``from_json`` builds the result for the JSON
            and msgpack output formats

        Returns
        -------
        The decoded object, or None when the response carries no result and
        the session has already been deleted by this caller.

        Raises
        ------
        SessionAlreadyClosedError
            when the response carries no result and the session was not
            deleted by this caller
        ValueError
            when the configured output format is not supported
        """
        if isinstance(content, (str, bytes)):
            if len(content) == 0:
                # treat an empty response body as an empty JSON object
                content = "{}"
            json_data = json.loads(to_str(content))
        else:
            json_data = content

        encoded_result = json_data.get("result")
        if not encoded_result:
            if self._deleted:
                return None
            else:
                raise SessionAlreadyClosedError(self._instance.id)

        # the result payload is base64-encoded inside the JSON envelope
        result_data = base64.b64decode(encoded_result)
        if self._output_format == MAXFRAME_OUTPUT_MAXFRAME_FORMAT:
            return deserialize_serializable(result_data)
        elif self._output_format == MAXFRAME_OUTPUT_JSON_FORMAT:
            return target_cls.from_json(json.loads(result_data))
        elif self._output_format == MAXFRAME_OUTPUT_MSGPACK_FORMAT:
            return target_cls.from_json(msgpack.loads(result_data))
        else:
            raise ValueError(
                f"Serialization format {self._output_format} not supported"
            )

    def _create_maxframe_task(self) -> MaxFrameTask:
        """
        Build a MaxFrameTask carrying the settings to upload, the output
        format and, when available, the quota name and client version.
        """
        task = MaxFrameTask(name=self._task_name, major_version=self._major_version)
        mf_settings = self.get_settings_to_upload()
        mf_opts = {
            "odps.maxframe.settings": json.dumps(mf_settings),
            "odps.maxframe.output_format": self._output_format,
        }
        if mf_settings.get("session.quota_name", None):
            mf_opts["odps.task.wlm.quota"] = mf_settings["session.quota_name"]
        if mf_version:
            mf_opts["odps.maxframe.client_version"] = mf_version
        task.update_settings(mf_opts)
        return task

    def create_session(self) -> SessionInfo:
        """
        Create a MaxFrame session.

        In non-nested mode a new instance is created and waited on until its
        task reports running, then the session info is fetched. In nested
        mode the create-session request is sent to the existing instance.

        Returns
        -------
        Information of the created session
        """
        task = self._create_maxframe_task()
        if not self._nested:
            self._task_name = task.name
            project = self._odps_entry.get_project(self._project)
            self._instance = project.instances.create(
                task=task,
                priority=self._priority,
                running_cluster=self._running_cluster,
            )
            self._wait_instance_task_ready()
            return self.get_session()
        else:
            result = self._instance.put_task_info(
                self._task_name,
                MAXFRAME_TASK_CREATE_SESSION_METHOD,
                task.properties["settings"],
            )
            return self._deserial_task_info_result(result, SessionInfo)

    def _parse_instance_result_error(self):
        """
        Raise the error recorded in the task result of a terminated instance.

        The task result is first decoded as a SessionInfo whose error info is
        re-raised. When decoding fails, the raw result is converted with
        ``parse_instance_error`` and raised instead.
        """
        result_data = self._instance.get_task_result(self._task_name)
        try:
            info = self._deserial_task_info_result({"result": result_data}, SessionInfo)
        except Exception:
            # Only ordinary errors mean "result is not a SessionInfo";
            # KeyboardInterrupt / SystemExit must propagate untouched.
            raise parse_instance_error(result_data)
        info.error_info.reraise()

    def _wait_instance_task_ready(
        self,
        interval: float = 0.1,
        max_interval: float = 5.0,
        timeout: Optional[int] = None,
    ):
        """
        Poll the instance until the task status becomes "Running".

        Parameters
        ----------
        interval: initial seconds to sleep between polls, doubled after
            every poll
        max_interval: upper bound of the sleep interval in seconds
        timeout: seconds to wait before giving up, defaults to
            ``options.client.task_start_timeout`` when not given

        Raises
        ------
        TimeoutError
            when the task is not running within the timeout
        """
        check_time = time.time()
        timeout = timeout or options.client.task_start_timeout
        while True:
            if self._instance.is_terminated(retry=True):
                # instance ended before the task became ready: surface
                # the error stored in the task result
                self._parse_instance_result_error()
            status_json = json.loads(
                self._instance.get_task_info(self._task_name, "status") or "{}"
            )
            if status_json.get("status") == "Running":
                break
            if time.time() - check_time > timeout:
                raise TimeoutError("Check session startup time out")
            time.sleep(interval)
            # exponential backoff capped at max_interval
            interval = min(max_interval, interval * 2)

    def _put_task_info(self, method_name: str, json_data: dict):
        """
        Send a request to the task, retrying on empty responses.

        Parameters
        ----------
        method_name: name of the task method to invoke
        json_data: request payload, serialized as JSON before sending

        Returns
        -------
        Response returned by the instance's ``put_task_info``

        Raises
        ------
        NoTaskServerResponseError
            when the empty response is not an HTTP 204 or retries run out
        """
        for trial in range(EMPTY_RESPONSE_RETRY_COUNT):
            try:
                return self._instance.put_task_info(
                    self._task_name,
                    method_name,
                    json.dumps(json_data),
                    raise_empty=True,
                )
            except EmptyTaskInfoError as ex:
                # retry when server returns HTTP 204, which is designed for retry
                if ex.code != 204 or trial >= EMPTY_RESPONSE_RETRY_COUNT - 1:
                    raise NoTaskServerResponseError(
                        f"No response for request {method_name}. "
                        f"Instance ID: {self._instance.id}. "
                        f"Request ID: {ex.request_id}"
                    ) from None
                time.sleep(0.5)

    def get_session(self) -> SessionInfo:
        """
        Fetch session information from the task.

        Returns
        -------
        Session info whose ``session_id`` is set to the instance id
        """
        req_data = {"output_format": self._output_format}
        serialized = self._put_task_info(MAXFRAME_TASK_GET_SESSION_METHOD, req_data)
        info: SessionInfo = self._deserial_task_info_result(serialized, SessionInfo)
        info.session_id = self._instance.id
        return info

    def delete_session(self) -> None:
        """
        Delete the session and mark this caller as deleted.

        In non-nested mode the instance is stopped; in nested mode a
        delete-session request is sent to the task instead.
        """
        if not self._nested:
            self._instance.stop()
        else:
            req_data = {"output_format": self._output_format}
            self._put_task_info(MAXFRAME_TASK_DELETE_SESSION_METHOD, req_data)
        self._deleted = True

    def submit_dag(
        self,
        dag: TileableGraph,
        managed_input_infos: Optional[Dict[str, ResultInfo]] = None,
        new_settings: Optional[Dict[str, Any]] = None,
    ) -> DagInfo:
        """
        Submit a tileable graph for execution.

        Parameters
        ----------
        dag: tileable graph to submit
        managed_input_infos: result infos keyed by string, serialized and
            sent along with the graph
        new_settings: settings sent along with the request

        Returns
        -------
        Information of the submitted DAG
        """
        new_settings_value = {
            "odps.maxframe.settings": json.dumps(new_settings),
        }
        req_data = {
            "protocol": MAXFRAME_DEFAULT_PROTOCOL,
            # binary payloads are base64-encoded to fit into the JSON request
            "dag": base64.b64encode(serialize_serializable(dag)).decode(),
            "managed_input_infos": base64.b64encode(
                serialize_serializable(managed_input_infos)
            ).decode(),
            "new_settings": json.dumps(new_settings_value),
            "output_format": self._output_format,
        }
        res = self._put_task_info(MAXFRAME_TASK_SUBMIT_DAG_METHOD, req_data)
        return self._deserial_task_info_result(res, DagInfo)

    def get_dag_info(self, dag_id: str) -> DagInfo:
        """
        Query information of a DAG.

        Parameters
        ----------
        dag_id: id of the DAG to query

        Returns
        -------
        Information of the DAG
        """
        req_data = {
            "protocol": MAXFRAME_DEFAULT_PROTOCOL,
            "dag_id": dag_id,
            "output_format": self._output_format,
        }
        res = self._put_task_info(MAXFRAME_TASK_GET_DAG_INFO_METHOD, req_data)
        return self._deserial_task_info_result(res, DagInfo)

    def cancel_dag(self, dag_id: str) -> DagInfo:
        """
        Request cancellation of a DAG.

        Parameters
        ----------
        dag_id: id of the DAG to cancel

        Returns
        -------
        Information of the DAG decoded from the response
        """
        req_data = {
            "protocol": MAXFRAME_DEFAULT_PROTOCOL,
            "dag_id": dag_id,
            "output_format": self._output_format,
        }
        res = self._put_task_info(MAXFRAME_TASK_CANCEL_DAG_METHOD, req_data)
        return self._deserial_task_info_result(res, DagInfo)

    def decref(self, tileable_keys: List[str]) -> None:
        """
        Send a decref request for the given tileables.

        Parameters
        ----------
        tileable_keys: keys of the tileables, sent as a comma-joined string
        """
        req_data = {
            "tileable_keys": ",".join(tileable_keys),
        }
        self._put_task_info(MAXFRAME_TASK_DECREF_METHOD, req_data)

    def get_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
        """
        Generate logview address

        Parameters
        ----------
        dag_id: id of dag for which dag logview detail page to access
        hours: hours of the logview address auth limit, defaults to
            ``options.session.logview_hours`` when not given

        Returns
        -------
        Logview address
        """
        hours = hours or options.session.logview_hours
        # notice: maxframe can't reuse subQuery else will conflict with mcqa when fetch resource data,
        #  added dagId for maxframe so logview backend will return maxframe data format if
        #  instance and dagId is provided.
        dag_suffix = f"&dagId={dag_id}" if dag_id else ""
        return self._instance.get_logview_address(hours) + dag_suffix
class MaxFrameTaskSession(MaxFrameSession):
    """
    MaxFrame session served by an ODPS instance via MaxFrameInstanceCaller.

    Registered under both the insecure and the secure ODPS session schemes.
    """

    schemes = [ODPS_SESSION_INSECURE_SCHEME, ODPS_SESSION_SECURE_SCHEME]

    # service caller bound to this session
    _caller: MaxFrameInstanceCaller

    @classmethod
    def _create_caller(
        cls,
        odps_entry: ODPS,
        address: str,
        priority: Optional[int] = None,
        project: Optional[str] = None,
        running_cluster: Optional[str] = None,
        **kwargs,
    ) -> MaxFrameServiceCaller:
        """
        Build the instance-based caller for this session type.

        Parameters
        ----------
        odps_entry: ODPS entry object handed to the caller
        address: session address, accepted but not used here
        priority: instance priority forwarded to the caller
        project: project name forwarded to the caller
        running_cluster: cluster name forwarded to the caller
        kwargs: extra keyword arguments forwarded to the caller as-is

        Returns
        -------
        A MaxFrameInstanceCaller instance
        """
        caller_options = {
            **kwargs,
            "project": project,
            "priority": priority,
            "running_cluster": running_cluster,
        }
        return MaxFrameInstanceCaller(odps_entry, **caller_options)

    @property
    def closed(self) -> bool:
        """
        Whether the session is closed: either the base session reports closed
        or the backing instance has terminated.
        """
        if super().closed:
            return True

        caller = self._caller
        instance = caller.instance if caller else None
        if not instance:
            # session not initialized yet
            return False
        return instance.is_terminated()
def register_session_schemes(overwrite: bool = False):
    """
    Register the schemes of MaxFrameTaskSession.

    Parameters
    ----------
    overwrite: forwarded to ``MaxFrameTaskSession.register_schemes``
    """
    session_cls = MaxFrameTaskSession
    session_cls.register_schemes(overwrite=overwrite)