# core/maxframe/protocol.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import enum
import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
import pandas as pd
from .core import OutputType, TileableGraph
from .core.graph.entity import SerializableGraph
from .lib import wrapped_pickle as pickle
from .lib.tblib import pickling_support
from .serialization import PickleContainer, RemoteException, pickle_buffers
from .serialization.serializables import (
AnyField,
BoolField,
BytesField,
DictField,
EnumField,
FieldTypes,
Float64Field,
Int32Field,
ListField,
ReferenceField,
Serializable,
SeriesField,
StringField,
)
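# make tracebacks picklable (via tblib) so remote errors can be
# transported and re-raised with their original stack information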
pickling_support.install()
BodyType = TypeVar("BodyType", bound="Serializable")
class JsonSerializable(Serializable):
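    """Base class for protocol objects that also have a JSON form.

    Subclasses implement :meth:`to_json` and :meth:`from_json` in
    addition to the serialization support inherited from ``Serializable``.
    """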
_ignore_non_existing_keys = True
@classmethod
def from_json(cls, serialized: dict) -> "JsonSerializable":
raise NotImplementedError
def to_json(self) -> dict:
raise NotImplementedError
class ProtocolBody(Generic[BodyType], Serializable):
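    """Envelope that pairs a request body with a unique request id."""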
request_id: bytes = BytesField(
"request_id", default_factory=lambda: uuid.uuid4().bytes
)
body: BodyType = AnyField("body", default=None)
class DagStatus(enum.Enum):
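    """Lifecycle states of a submitted DAG."""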
PREPARING = 0
RUNNING = 1
SUCCEEDED = 2
FAILED = 3
CANCELLING = 4
CANCELLED = 5
def is_terminated(self):
return self in (DagStatus.CANCELLED, DagStatus.SUCCEEDED, DagStatus.FAILED)
class DimensionIndex(Serializable):
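    """Index selection (a slice or an integer index) along one dimension."""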
is_slice: bool = BoolField("is_slice", default=None)
is_int_index: bool = BoolField("is_int_index", default=None)
data: List = ListField("data", default=None)
class ResultType(enum.Enum):
NULL = 0
ODPS_TABLE = 1
ODPS_VOLUME = 2
class DataSerializeType(enum.Enum):
PICKLE = 0
_result_type_to_info_cls: Dict[ResultType, Type["ResultInfo"]] = dict()
class ResultInfo(JsonSerializable):
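    """Describes where the result of a tileable is stored.

    Subclasses declare their kind via ``_result_type`` and are picked up
    automatically by :meth:`from_json`. Illustrative example::

        info = ResultInfo.from_json(
            {"result_type": 1, "full_table_name": "my_table"}
        )
        # isinstance(info, ODPSTableResultInfo) is True
    """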
_result_type = ResultType.NULL
result_type: Optional[ResultType] = EnumField(
"result_type", ResultType, default=ResultType.NULL
)
slices: Optional[List[DimensionIndex]] = ListField(
"slices", FieldTypes.reference, default=None
)
@classmethod
def _cls_from_result_type(cls, result_type: ResultType):
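        # lazily populate the ResultType -> ResultInfo subclass registry
        # by scanning module globals on first use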
if not _result_type_to_info_cls:
for dest_cls in globals().values():
if isinstance(dest_cls, type) and issubclass(dest_cls, ResultInfo):
_result_type_to_info_cls[dest_cls._result_type] = dest_cls
return _result_type_to_info_cls[result_type]
@classmethod
def _json_to_kwargs(cls, serialized: dict) -> dict:
# todo retrieve slices from json once implemented
kw = serialized.copy()
kw["result_type"] = ResultType(kw["result_type"])
return kw
@classmethod
def from_json(cls, serialized: dict) -> "ResultInfo":
res_type = ResultType(serialized["result_type"])
res_cls = cls._cls_from_result_type(res_type)
return res_cls(**res_cls._json_to_kwargs(serialized))
def to_json(self) -> dict:
# todo convert slices to json once implemented
return {"result_type": self.result_type.value if self.result_type else None}
ResultInfoType = TypeVar("ResultInfoType", bound=ResultInfo)
class ODPSTableResultInfo(ResultInfo):
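    """Result stored as an ODPS table, optionally limited to partitions."""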
_result_type = ResultType.ODPS_TABLE
full_table_name: str = StringField("full_table_name", default=None)
partition_specs: Optional[List[str]] = ListField(
"partition_specs", FieldTypes.string, default=None
)
table_meta: Optional["DataFrameTableMeta"] = ReferenceField(
"table_meta", default=None
)
def __init__(self, result_type: ResultType = None, **kw):
result_type = result_type or ResultType.ODPS_TABLE
super().__init__(result_type=result_type, **kw)
def to_json(self) -> dict:
ret = super().to_json()
ret["full_table_name"] = self.full_table_name
if self.partition_specs:
ret["partition_specs"] = self.partition_specs
if self.table_meta:
ret["table_meta"] = self.table_meta.to_json()
return ret
@classmethod
def _json_to_kwargs(cls, serialized: dict) -> dict:
kw = super()._json_to_kwargs(serialized)
if "table_meta" in kw:
kw["table_meta"] = DataFrameTableMeta.from_json(kw["table_meta"])
return kw
class ODPSVolumeResultInfo(ResultInfo):
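    """Result stored under a path inside an ODPS volume."""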
_result_type = ResultType.ODPS_VOLUME
volume_name: str = StringField("volume_name", default=None)
volume_path: str = StringField("volume_path", default=None)
def __init__(self, result_type: ResultType = None, **kw):
result_type = result_type or ResultType.ODPS_VOLUME
super().__init__(result_type=result_type, **kw)
def to_json(self) -> dict:
ret = super().to_json()
ret["volume_name"] = self.volume_name
ret["volume_path"] = self.volume_path
return ret
class ErrorSource(enum.Enum):
PYTHON = 0
class ErrorInfo(JsonSerializable):
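    """Error details with messages, tracebacks and, when picklable, the
    raw exception object.

    Illustrative example::

        try:
            1 / 0
        except ZeroDivisionError as exc:
            info = ErrorInfo.from_exception(exc)
        info.reraise()  # re-raises the original ZeroDivisionError
    """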
error_messages: Optional[List[str]] = ListField("error_messages", FieldTypes.string)
error_tracebacks: Optional[List[List[str]]] = ListField(
"error_tracebacks", FieldTypes.list
)
raw_error_source: ErrorSource = EnumField(
"raw_error_source", ErrorSource, FieldTypes.int8, default=None
)
raw_error_data: Optional[Exception] = AnyField("raw_error_data", default=None)
@classmethod
def from_exception(cls, exc: Exception):
remote_exc = RemoteException.from_exception(exc)
messages, tracebacks = remote_exc.messages, remote_exc.tracebacks
return cls(messages, tracebacks, ErrorSource.PYTHON, exc)
def reraise(self):
if (
self.raw_error_source == ErrorSource.PYTHON
and self.raw_error_data is not None
):
raise self.raw_error_data
raise RemoteException(self.error_messages, self.error_tracebacks, [])
@classmethod
def from_json(cls, serialized: dict) -> "ErrorInfo":
kw = serialized.copy()
if kw.get("raw_error_source") is not None:
kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
else:
kw["raw_error_source"] = None
if kw.get("raw_error_data"):
bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
try:
kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
            except Exception:
                # reset both the error source and the data so that
                # reraise() falls back to raising RemoteException
                kw["raw_error_source"] = kw["raw_error_data"] = None
return cls(**kw)
def to_json(self) -> dict:
ret = {
"error_messages": self.error_messages,
"error_tracebacks": self.error_tracebacks,
"raw_error_source": self.raw_error_source.value,
}
err_data_bufs = None
if isinstance(self.raw_error_data, (PickleContainer, RemoteException)):
err_data_bufs = self.raw_error_data.get_buffers()
elif isinstance(self.raw_error_data, BaseException):
try:
err_data_bufs = pickle_buffers(self.raw_error_data)
            except Exception:
err_data_bufs = None
ret["raw_error_source"] = None
if err_data_bufs:
ret["raw_error_data"] = [
base64.b64encode(s).decode() for s in err_data_bufs
]
return ret
class DagInfo(JsonSerializable):
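    """Status snapshot of a DAG execution, including per-tileable result
    locations, error details and information about its sub-DAGs."""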
session_id: str = StringField("session_id", default=None)
dag_id: str = StringField("dag_id", default=None)
status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
progress: float = Float64Field("progress", default=None)
tileable_to_result_infos: Dict[str, ResultInfo] = DictField(
"tileable_to_result_infos",
FieldTypes.string,
FieldTypes.reference,
default_factory=dict,
)
error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
subdag_infos: Dict[str, "SubDagInfo"] = DictField(
"subdag_infos",
key_type=FieldTypes.string,
value_type=FieldTypes.reference,
default_factory=dict,
)
@classmethod
def from_json(cls, serialized: dict) -> Optional["DagInfo"]:
if serialized is None:
return None
kw = serialized.copy()
kw["status"] = DagStatus(kw["status"])
if kw.get("tileable_to_result_infos"):
kw["tileable_to_result_infos"] = {
k: ResultInfo.from_json(s)
for k, s in kw["tileable_to_result_infos"].items()
}
if kw.get("error_info"):
kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
if kw.get("subdag_infos"):
kw["subdag_infos"] = {
k: SubDagInfo.from_json(v) for k, v in kw["subdag_infos"].items()
}
return DagInfo(**kw)
def to_json(self) -> dict:
ret = {
"session_id": self.session_id,
"dag_id": self.dag_id,
"status": self.status.value,
"progress": self.progress,
"start_timestamp": self.start_timestamp,
"end_timestamp": self.end_timestamp,
}
ret = {k: v for k, v in ret.items() if v is not None}
if self.tileable_to_result_infos:
ret["tileable_to_result_infos"] = {
k: v.to_json() for k, v in self.tileable_to_result_infos.items()
}
if self.error_info:
ret["error_info"] = self.error_info.to_json()
if self.subdag_infos:
ret["subdag_infos"] = {k: v.to_json() for k, v in self.subdag_infos.items()}
return ret
class CreateSessionRequest(Serializable):
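    """Request to create a session with optional initial settings."""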
settings: Dict[str, Any] = DictField("settings", default=None)
class SessionInfo(JsonSerializable):
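    """Status snapshot of a session and the DAGs submitted under it."""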
session_id: str = StringField("session_id")
settings: Dict[str, Any] = DictField(
"settings", key_type=FieldTypes.string, default=None
)
start_timestamp: float = Float64Field("start_timestamp", default=None)
idle_timestamp: float = Float64Field("idle_timestamp", default=None)
dag_infos: Dict[str, Optional[DagInfo]] = DictField(
"dag_infos",
key_type=FieldTypes.string,
value_type=FieldTypes.reference,
default=None,
)
error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
@classmethod
def from_json(cls, serialized: dict) -> Optional["SessionInfo"]:
if serialized is None:
return None
kw = serialized.copy()
if kw.get("dag_infos"):
kw["dag_infos"] = {
k: DagInfo.from_json(v) for k, v in kw["dag_infos"].items()
}
if kw.get("error_info"):
kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
return SessionInfo(**kw)
def to_json(self) -> dict:
ret = {
"session_id": self.session_id,
"settings": self.settings,
"start_timestamp": self.start_timestamp,
"idle_timestamp": self.idle_timestamp,
}
if self.dag_infos:
ret["dag_infos"] = {
k: v.to_json() if v is not None else None
for k, v in self.dag_infos.items()
}
if self.error_info:
ret["error_info"] = self.error_info.to_json()
return ret
class ExecuteDagRequest(Serializable):
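    """Request to execute a tileable graph within an existing session."""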
session_id: str = StringField("session_id")
dag: TileableGraph = ReferenceField(
"dag",
on_serialize=SerializableGraph.from_graph,
on_deserialize=lambda x: x.to_graph(),
default=None,
)
managed_input_infos: Dict[str, ResultInfo] = DictField(
"managed_input_infos",
key_type=FieldTypes.string,
value_type=FieldTypes.reference,
default=None,
)
new_settings: Dict[str, Any] = DictField(
"new_settings",
key_type=FieldTypes.string,
default=None,
)
class SubDagSubmitInstanceInfo(JsonSerializable):
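    """Records an instance submitted while running a sub-DAG, together
    with the reason for the submission."""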
submit_reason: str = StringField("submit_reason")
instance_id: str = StringField("instance_id")
subquery_id: Optional[int] = Int32Field("subquery_id", default=None)
@classmethod
def from_json(cls, serialized: dict) -> "SubDagSubmitInstanceInfo":
return SubDagSubmitInstanceInfo(**serialized)
def to_json(self) -> dict:
ret = {
"submit_reason": self.submit_reason,
"instance_id": self.instance_id,
"subquery_id": self.subquery_id,
}
return ret
class SubDagInfo(JsonSerializable):
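    """Status snapshot of a single sub-DAG of a DAG execution."""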
subdag_id: str = StringField("subdag_id")
status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
progress: float = Float64Field("progress", default=None)
error_info: Optional[ErrorInfo] = ReferenceField(
"error_info", reference_type=ErrorInfo, default=None
)
tileable_to_result_infos: Dict[str, ResultInfo] = DictField(
"tileable_to_result_infos",
FieldTypes.string,
FieldTypes.reference,
default_factory=dict,
)
start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
submit_instances: List[SubDagSubmitInstanceInfo] = ListField(
"submit_instances",
FieldTypes.reference,
default_factory=list,
)
@classmethod
def from_json(cls, serialized: dict) -> "SubDagInfo":
kw = serialized.copy()
kw["status"] = DagStatus(kw["status"])
if kw.get("tileable_to_result_infos"):
kw["tileable_to_result_infos"] = {
k: ResultInfo.from_json(s)
for k, s in kw["tileable_to_result_infos"].items()
}
if kw.get("error_info"):
kw["error_info"] = ErrorInfo.from_json(kw["error_info"])
if kw.get("submit_instances"):
kw["submit_instances"] = [
SubDagSubmitInstanceInfo.from_json(s) for s in kw["submit_instances"]
]
return SubDagInfo(**kw)
def to_json(self) -> dict:
ret = {
"subdag_id": self.subdag_id,
"status": self.status.value,
"progress": self.progress,
"start_timestamp": self.start_timestamp,
"end_timestamp": self.end_timestamp,
}
if self.error_info:
ret["error_info"] = self.error_info.to_json()
if self.tileable_to_result_infos:
ret["tileable_to_result_infos"] = {
k: v.to_json() for k, v in self.tileable_to_result_infos.items()
}
if self.submit_instances:
ret["submit_instances"] = [i.to_json() for i in self.submit_instances]
return ret
class ExecuteSubDagRequest(Serializable):
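    """Request to execute one sub-DAG with optional settings."""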
subdag_id: str = StringField("subdag_id")
dag: TileableGraph = ReferenceField(
"dag",
on_serialize=SerializableGraph.from_graph,
on_deserialize=lambda x: x.to_graph(),
default=None,
)
settings: Dict[str, Any] = DictField("settings", default=None)
class DecrefRequest(Serializable):
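    """Request to decrease reference counts for the given keys."""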
keys: List[str] = ListField("keys", FieldTypes.string, default=None)
class DataFrameTableMeta(JsonSerializable):
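    """Metadata mapping a pandas object to its backing table: table
    column names plus the original pandas dtypes and level names.

    Illustrative round trip (with all fields populated)::

        meta == DataFrameTableMeta.from_json(meta.to_json())
    """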
__slots__ = "_pd_column_names", "_pd_index_level_names"
table_name: Optional[str] = StringField("table_name", default=None)
type: OutputType = EnumField("type", OutputType, FieldTypes.int8, default=None)
table_column_names: List[str] = ListField(
"table_column_names", FieldTypes.string, default=None
)
table_index_column_names: List[str] = ListField(
"table_index_column_names", FieldTypes.string, default=None
)
pd_column_dtypes: pd.Series = SeriesField("pd_column_dtypes", default=None)
pd_column_level_names: List[Any] = ListField("pd_column_level_names", default=None)
pd_index_dtypes: pd.Series = SeriesField("pd_index_dtypes", default=None)
@property
def pd_column_names(self) -> List[Any]:
try:
return self._pd_column_names
except AttributeError:
self._pd_column_names = self.pd_column_dtypes.index.tolist()
return self._pd_column_names
@property
def pd_index_level_names(self) -> List[Any]:
try:
return self._pd_index_level_names
except AttributeError:
self._pd_index_level_names = self.pd_index_dtypes.index.tolist()
return self._pd_index_level_names
def __eq__(self, other: "DataFrameTableMeta") -> bool:
if not isinstance(other, type(self)):
return False
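        # field values may be pandas objects whose == is element-wise;
        # collapse such results with .all()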
for k in self._FIELDS:
v = getattr(self, k, None)
is_same = v == getattr(other, k, None)
if callable(getattr(is_same, "all", None)):
is_same = is_same.all()
if not is_same:
return False
return True
def to_json(self) -> dict:
b64_pk = lambda x: base64.b64encode(pickle.dumps(x)).decode()
ret = {
"table_name": self.table_name,
"type": self.type.value,
"table_column_names": self.table_column_names,
"table_index_column_names": self.table_index_column_names,
"pd_column_dtypes": b64_pk(self.pd_column_dtypes),
"pd_column_level_names": b64_pk(self.pd_column_level_names),
"pd_index_dtypes": b64_pk(self.pd_index_dtypes),
}
return ret
@classmethod
def from_json(cls, serialized: dict) -> "DataFrameTableMeta":
b64_upk = lambda x: pickle.loads(base64.b64decode(x))
        kw = serialized.copy()  # do not mutate the caller's dict
        kw.update(
            {
                "type": OutputType(kw["type"]),
                "pd_column_dtypes": b64_upk(kw["pd_column_dtypes"]),
                "pd_column_level_names": b64_upk(kw["pd_column_level_names"]),
                "pd_index_dtypes": b64_upk(kw["pd_index_dtypes"]),
            }
        )
        return DataFrameTableMeta(**kw)