core/maxframe_client/session/task.py (259 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
import logging
import time
from typing import Any, Dict, List, Optional, Type, Union

import msgpack
from odps import ODPS
from odps import options as odps_options
from odps.errors import EmptyTaskInfoError, parse_instance_error
from odps.models import Instance, MaxFrameTask

from maxframe.config import options
from maxframe.core import TileableGraph
from maxframe.errors import NoTaskServerResponseError, SessionAlreadyClosedError
from maxframe.protocol import DagInfo, JsonSerializable, ResultInfo, SessionInfo
from maxframe.utils import deserialize_serializable, serialize_serializable, to_str

try:
    from maxframe import __version__ as mf_version
except ImportError:
    mf_version = None

from .consts import (
    EMPTY_RESPONSE_RETRY_COUNT,
    MAXFRAME_DEFAULT_PROTOCOL,
    MAXFRAME_OUTPUT_JSON_FORMAT,
    MAXFRAME_OUTPUT_MAXFRAME_FORMAT,
    MAXFRAME_OUTPUT_MSGPACK_FORMAT,
    MAXFRAME_TASK_CANCEL_DAG_METHOD,
    MAXFRAME_TASK_CREATE_SESSION_METHOD,
    MAXFRAME_TASK_DECREF_METHOD,
    MAXFRAME_TASK_DELETE_SESSION_METHOD,
    MAXFRAME_TASK_GET_DAG_INFO_METHOD,
    MAXFRAME_TASK_GET_SESSION_METHOD,
    MAXFRAME_TASK_SUBMIT_DAG_METHOD,
    ODPS_SESSION_INSECURE_SCHEME,
    ODPS_SESSION_SECURE_SCHEME,
)
from .odps import MaxFrameServiceCaller, MaxFrameSession

logger = logging.getLogger(__name__)


class MaxFrameInstanceCaller(MaxFrameServiceCaller):
    """
    Service caller that drives a MaxFrame task through an ODPS instance.

    Requests are sent with the instance's ``put_task_info`` /
    ``get_task_info`` calls and responses are decoded according to the
    configured output format.
    """

    # ODPS instance hosting the task; None until a session is created
    # unless the caller was built around an existing (nested) instance.
    _instance: Optional[Instance]

    def __init__(
        self,
        odps_entry: ODPS,
        task_name: Optional[str] = None,
        project: Optional[str] = None,
        priority: Optional[int] = None,
        running_cluster: Optional[str] = None,
        nested_instance_id: Optional[str] = None,
        major_version: Optional[str] = None,
        output_format: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        odps_entry
            ODPS entry object used to look up projects and instances.
        task_name
            Name of the MaxFrame task. Overwritten with the created task's
            name when a non-nested session is created.
        project
            Project passed to ``odps_entry.get_project`` when creating
            the instance.
        priority
            Instance priority. When None, the default is taken from
            ``odps_options.get_priority(odps_entry)`` if that attribute is
            callable, otherwise from ``odps_options.priority``.
        running_cluster
            Passed through to instance creation.
        nested_instance_id
            When given, the caller attaches to this existing instance and
            runs in nested mode instead of creating a new instance.
        major_version
            Passed to the ``MaxFrameTask`` constructor.
        output_format
            Response serialization format; defaults to
            ``MAXFRAME_OUTPUT_MSGPACK_FORMAT`` when falsy.
        **kwargs
            Accepted for call compatibility and ignored here.
        """
        if callable(odps_options.get_priority):
            default_priority = odps_options.get_priority(odps_entry)
        else:
            default_priority = odps_options.priority
        priority = priority if priority is not None else default_priority

        self._odps_entry = odps_entry
        self._task_name = task_name
        self._project = project
        self._priority = priority
        self._running_cluster = running_cluster
        self._major_version = major_version
        self._output_format = output_format or MAXFRAME_OUTPUT_MSGPACK_FORMAT
        self._deleted = False

        if nested_instance_id is None:
            self._nested = False
            self._instance = None
        else:
            self._nested = True
            self._instance = odps_entry.get_instance(nested_instance_id)

    @property
    def instance(self) -> Optional[Instance]:
        """The ODPS instance in use, or None if none has been set yet."""
        return self._instance

    def _deserial_task_info_result(
        self, content: Union[bytes, str, dict], target_cls: Type[JsonSerializable]
    ):
        """
        Decode a task-info response into an object.

        Parameters
        ----------
        content
            Either a JSON document (str or bytes; an empty value is treated
            as ``"{}"``) or an already-parsed dict. Its ``"result"`` entry is
            expected to hold the base64-encoded payload.
        target_cls
            Class whose ``from_json`` builds the result for the JSON and
            msgpack output formats.

        Returns
        -------
        The decoded object, or None when the response carries no result and
        this caller has already deleted its session.

        Raises
        ------
        SessionAlreadyClosedError
            If the response carries no result and the session has not been
            deleted by this caller.
        ValueError
            If the configured output format is not one of the known formats.
        """
        if isinstance(content, (str, bytes)):
            if len(content) == 0:
                content = "{}"
            json_data = json.loads(to_str(content))
        else:
            json_data = content

        encoded_result = json_data.get("result")
        if not encoded_result:
            if self._deleted:
                return None
            raise SessionAlreadyClosedError(self._instance.id)

        result_data = base64.b64decode(encoded_result)
        if self._output_format == MAXFRAME_OUTPUT_MAXFRAME_FORMAT:
            return deserialize_serializable(result_data)
        elif self._output_format == MAXFRAME_OUTPUT_JSON_FORMAT:
            return target_cls.from_json(json.loads(result_data))
        elif self._output_format == MAXFRAME_OUTPUT_MSGPACK_FORMAT:
            return target_cls.from_json(msgpack.loads(result_data))
        else:
            raise ValueError(
                f"Serialization format {self._output_format} not supported"
            )

    def _create_maxframe_task(self) -> MaxFrameTask:
        """
        Build the ``MaxFrameTask`` carrying the settings to upload.

        The settings returned by ``self.get_settings_to_upload()`` are sent
        as JSON together with the output format. The quota name (when set
        in the settings) and the client version (when importable) are added
        as extra task settings.
        """
        task = MaxFrameTask(name=self._task_name, major_version=self._major_version)
        mf_settings = self.get_settings_to_upload()
        mf_opts = {
            "odps.maxframe.settings": json.dumps(mf_settings),
            "odps.maxframe.output_format": self._output_format,
        }
        if mf_settings.get("session.quota_name", None):
            mf_opts["odps.task.wlm.quota"] = mf_settings["session.quota_name"]
        if mf_version:
            mf_opts["odps.maxframe.client_version"] = mf_version
        task.update_settings(mf_opts)
        return task

    def create_session(self) -> SessionInfo:
        """
        Create a session and return its info.

        In non-nested mode a new instance is created for the task, the call
        blocks until the task reports it is running, and the session info
        is then fetched. In nested mode the task settings are sent to the
        existing instance with the create-session method.
        """
        task = self._create_maxframe_task()
        if not self._nested:
            self._task_name = task.name
            project = self._odps_entry.get_project(self._project)
            self._instance = project.instances.create(
                task=task,
                priority=self._priority,
                running_cluster=self._running_cluster,
            )
            self._wait_instance_task_ready()
            return self.get_session()
        else:
            result = self._instance.put_task_info(
                self._task_name,
                MAXFRAME_TASK_CREATE_SESSION_METHOD,
                task.properties["settings"],
            )
            return self._deserial_task_info_result(result, SessionInfo)

    def _parse_instance_result_error(self):
        """
        Raise the error recorded in the task result of the instance.

        The task result is first decoded as a ``SessionInfo`` whose
        ``error_info`` is re-raised. If decoding fails, the raw result is
        handed to ``parse_instance_error`` and that error is raised instead.
        """
        result_data = self._instance.get_task_result(self._task_name)
        try:
            info = self._deserial_task_info_result({"result": result_data}, SessionInfo)
        except Exception:
            # Only ordinary failures trigger the fallback; KeyboardInterrupt
            # and SystemExit must propagate untouched rather than being
            # reported as an instance error.
            raise parse_instance_error(result_data)
        info.error_info.reraise()

    def _wait_instance_task_ready(
        self,
        interval: float = 0.1,
        max_interval: float = 5.0,
        timeout: Optional[int] = None,
    ):
        """
        Poll the task status until it is ``"Running"``.

        Parameters
        ----------
        interval
            Initial sleep between polls in seconds; doubled after every
            poll up to ``max_interval``.
        max_interval
            Upper bound of the sleep between polls in seconds.
        timeout
            Maximal seconds to wait. A falsy value selects
            ``options.client.task_start_timeout``.

        Raises
        ------
        TimeoutError
            If the task is not running within ``timeout`` seconds.
        """
        check_time = time.time()
        timeout = timeout or options.client.task_start_timeout
        while True:
            if self._instance.is_terminated(retry=True):
                # the instance ended before the task came up: surface why
                self._parse_instance_result_error()
            status_json = json.loads(
                self._instance.get_task_info(self._task_name, "status") or "{}"
            )
            if status_json.get("status") == "Running":
                break
            if time.time() - check_time > timeout:
                raise TimeoutError("Check session startup time out")
            time.sleep(interval)
            interval = min(max_interval, interval * 2)

    def _put_task_info(self, method_name: str, json_data: dict):
        """
        Send a request to the task and return the raw response.

        Empty responses with code 204 are retried after a 0.5 second pause,
        for at most ``EMPTY_RESPONSE_RETRY_COUNT`` attempts in total.

        Raises
        ------
        NoTaskServerResponseError
            If an empty response has a code other than 204, or the last
            attempt still returns an empty response.
        """
        for trial in range(EMPTY_RESPONSE_RETRY_COUNT):
            try:
                return self._instance.put_task_info(
                    self._task_name,
                    method_name,
                    json.dumps(json_data),
                    raise_empty=True,
                )
            except EmptyTaskInfoError as ex:
                # retry when server returns HTTP 204, which is designed for retry
                if ex.code != 204 or trial >= EMPTY_RESPONSE_RETRY_COUNT - 1:
                    raise NoTaskServerResponseError(
                        f"No response for request {method_name}. "
                        f"Instance ID: {self._instance.id}. "
                        f"Request ID: {ex.request_id}"
                    ) from None
                time.sleep(0.5)

    def get_session(self) -> SessionInfo:
        """Fetch the session info and set its session id to the instance id."""
        req_data = {"output_format": self._output_format}
        serialized = self._put_task_info(MAXFRAME_TASK_GET_SESSION_METHOD, req_data)
        info: SessionInfo = self._deserial_task_info_result(serialized, SessionInfo)
        info.session_id = self._instance.id
        return info

    def delete_session(self) -> None:
        """
        Delete the session and mark this caller as deleted.

        In non-nested mode the instance is stopped; in nested mode a
        delete-session request is sent to the task instead.
        """
        if not self._nested:
            self._instance.stop()
        else:
            req_data = {"output_format": self._output_format}
            self._put_task_info(MAXFRAME_TASK_DELETE_SESSION_METHOD, req_data)
        self._deleted = True

    def submit_dag(
        self,
        dag: TileableGraph,
        managed_input_infos: Optional[Dict[str, ResultInfo]] = None,
        new_settings: Optional[Dict[str, Any]] = None,
    ) -> DagInfo:
        """
        Submit a tileable graph for execution.

        Parameters
        ----------
        dag
            Graph to submit; serialized and base64-encoded in the request.
        managed_input_infos
            Serialized and base64-encoded in the request, also when None.
        new_settings
            Settings sent along with the DAG as JSON, also when None.

        Returns
        -------
        Info of the submitted DAG decoded from the response.
        """
        new_settings_value = {
            "odps.maxframe.settings": json.dumps(new_settings),
        }
        req_data = {
            "protocol": MAXFRAME_DEFAULT_PROTOCOL,
            "dag": base64.b64encode(serialize_serializable(dag)).decode(),
            "managed_input_infos": base64.b64encode(
                serialize_serializable(managed_input_infos)
            ).decode(),
            "new_settings": json.dumps(new_settings_value),
            "output_format": self._output_format,
        }
        res = self._put_task_info(MAXFRAME_TASK_SUBMIT_DAG_METHOD, req_data)
        return self._deserial_task_info_result(res, DagInfo)

    def get_dag_info(self, dag_id: str) -> DagInfo:
        """Request the info of the DAG identified by ``dag_id``."""
        req_data = {
            "protocol": MAXFRAME_DEFAULT_PROTOCOL,
            "dag_id": dag_id,
            "output_format": self._output_format,
        }
        res = self._put_task_info(MAXFRAME_TASK_GET_DAG_INFO_METHOD, req_data)
        return self._deserial_task_info_result(res, DagInfo)

    def cancel_dag(self, dag_id: str) -> DagInfo:
        """Send a cancel request for ``dag_id`` and return the decoded DAG info."""
        req_data = {
            "protocol": MAXFRAME_DEFAULT_PROTOCOL,
            "dag_id": dag_id,
            "output_format": self._output_format,
        }
        res = self._put_task_info(MAXFRAME_TASK_CANCEL_DAG_METHOD, req_data)
        return self._deserial_task_info_result(res, DagInfo)

    def decref(self, tileable_keys: List[str]) -> None:
        """Send a decref request with the given keys joined by commas."""
        req_data = {
            "tileable_keys": ",".join(tileable_keys),
        }
        self._put_task_info(MAXFRAME_TASK_DECREF_METHOD, req_data)

    def get_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
        """
        Generate logview address

        Parameters
        ----------
        dag_id: id of dag for which dag logview detail page to access
        hours: hours of the logview address auth limit; a falsy value
            selects ``options.session.logview_hours``

        Returns
        -------
        Logview address
        """
        hours = hours or options.session.logview_hours
        # NOTE: MaxFrame can't reuse the subQuery parameter, or it would
        # conflict with MCQA when fetching resource data. A dagId parameter
        # is appended instead, so that the logview backend returns
        # MaxFrame-formatted data when both instance and dagId are provided.
        dag_suffix = f"&dagId={dag_id}" if dag_id else ""
        return self._instance.get_logview_address(hours) + dag_suffix


class MaxFrameTaskSession(MaxFrameSession):
    """Session registered for the ODPS session schemes, backed by
    ``MaxFrameInstanceCaller``."""

    schemes = [ODPS_SESSION_INSECURE_SCHEME, ODPS_SESSION_SECURE_SCHEME]

    _caller: MaxFrameInstanceCaller

    @classmethod
    def _create_caller(
        cls,
        odps_entry: ODPS,
        address: str,
        priority: Optional[int] = None,
        project: Optional[str] = None,
        running_cluster: Optional[str] = None,
        **kwargs,
    ) -> MaxFrameServiceCaller:
        """
        Build the instance caller for this session.

        ``address`` is accepted but not used here; the remaining arguments
        are forwarded to ``MaxFrameInstanceCaller``.
        """
        return MaxFrameInstanceCaller(
            odps_entry,
            priority=priority,
            running_cluster=running_cluster,
            project=project,
            **kwargs,
        )

    @property
    def closed(self) -> bool:
        """
        Whether the session is closed.

        True when the parent class reports closed or the underlying
        instance is terminated; False while no caller or instance exists.
        """
        if super().closed:
            return True
        if not self._caller or not self._caller.instance:
            # session not initialized yet
            return False
        return self._caller.instance.is_terminated()


def register_session_schemes(overwrite: bool = False):
    """Register the schemes of ``MaxFrameTaskSession``, forwarding ``overwrite``."""
    MaxFrameTaskSession.register_schemes(overwrite=overwrite)