# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import enum
import uuid
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

import pandas as pd

from .core import OutputType, TileableGraph
from .core.graph.entity import SerializableGraph
from .lib import wrapped_pickle as pickle
from .lib.tblib import pickling_support
from .serialization import PickleContainer, RemoteException, pickle_buffers
from .serialization.serializables import (
    AnyField,
    BoolField,
    BytesField,
    DictField,
    EnumField,
    FieldTypes,
    Float64Field,
    Int32Field,
    ListField,
    ReferenceField,
    Serializable,
    SeriesField,
    StringField,
)

# Register traceback pickling support so exceptions (with tracebacks) can be
# shipped across process boundaries — used by ErrorInfo below.
pickling_support.install()

# Type of the payload carried inside a ProtocolBody envelope.
BodyType = TypeVar("BodyType", bound="Serializable")


class JsonSerializable(Serializable):
    """Base class for serializables that can also round-trip through JSON.

    Subclasses implement :meth:`from_json` / :meth:`to_json` as the JSON
    counterpart of the binary serialization provided by ``Serializable``.
    """

    # presumably tells the Serializable machinery to skip unknown keys on
    # deserialization — confirm against the .serialization package
    _ignore_non_existing_keys = True

    @classmethod
    def from_json(cls, serialized: dict) -> "JsonSerializable":
        """Build an instance from a plain dict produced by ``to_json``."""
        raise NotImplementedError

    def to_json(self) -> dict:
        """Convert this object into a JSON-compatible plain dict."""
        raise NotImplementedError


class ProtocolBody(Generic[BodyType], Serializable):
    """Generic protocol envelope pairing a request id with a payload."""

    # random 16-byte identifier generated per request
    request_id: bytes = BytesField(
        "request_id", default_factory=lambda: uuid.uuid4().bytes
    )
    # actual message payload; static type given by the BodyType type variable
    body: BodyType = AnyField("body", default=None)


class DagStatus(enum.Enum):
    """Life-cycle states of a submitted DAG."""

    PREPARING = 0
    RUNNING = 1
    SUCCEEDED = 2
    FAILED = 3
    CANCELLING = 4
    CANCELLED = 5

    def is_terminated(self) -> bool:
        """Return True once the DAG has reached a final state."""
        terminal_states = {
            DagStatus.SUCCEEDED,
            DagStatus.FAILED,
            DagStatus.CANCELLED,
        }
        return self in terminal_states


class DimensionIndex(Serializable):
    """Selection along one dimension of a result (slice or explicit values)."""

    # True when ``data`` encodes a slice — TODO confirm exact encoding
    is_slice: bool = BoolField("is_slice", default=None)
    # True when ``data`` holds integer positions — TODO confirm semantics
    is_int_index: bool = BoolField("is_int_index", default=None)
    # index payload; interpretation depends on the flags above
    data: List = ListField("data", default=None)


class ResultType(enum.Enum):
    """Storage backend used to materialize an execution result."""

    NULL = 0
    ODPS_TABLE = 1
    ODPS_VOLUME = 2


class DataSerializeType(enum.Enum):
    """Wire format for raw data payloads; only pickle is defined so far."""

    PICKLE = 0


_result_type_to_info_cls: Dict[ResultType, Type["ResultInfo"]] = dict()


class ResultInfo(JsonSerializable):
    """Describes how the result of a tileable is materialized.

    Serves as both the null result and the base class for concrete result
    kinds; ``from_json`` dispatches to the subclass registered for the
    encoded ``result_type``.
    """

    _result_type = ResultType.NULL

    result_type: Optional[ResultType] = EnumField(
        "result_type", ResultType, default=ResultType.NULL
    )
    slices: Optional[List[DimensionIndex]] = ListField(
        "slices", FieldTypes.reference, default=None
    )

    @classmethod
    def _cls_from_result_type(cls, result_type: ResultType):
        """Resolve the ResultInfo subclass registered for ``result_type``."""
        if not _result_type_to_info_cls:
            # lazily fill the registry by scanning module globals once
            for candidate in globals().values():
                if not isinstance(candidate, type):
                    continue
                if issubclass(candidate, ResultInfo):
                    _result_type_to_info_cls[candidate._result_type] = candidate
        return _result_type_to_info_cls[result_type]

    @classmethod
    def _json_to_kwargs(cls, serialized: dict) -> dict:
        """Translate a JSON dict into constructor keyword arguments."""
        # todo retrieve slices from json once implemented
        kw = dict(serialized)
        kw["result_type"] = ResultType(kw["result_type"])
        return kw

    @classmethod
    def from_json(cls, serialized: dict) -> "ResultInfo":
        """Instantiate the concrete subclass encoded in ``serialized``."""
        info_cls = cls._cls_from_result_type(ResultType(serialized["result_type"]))
        return info_cls(**info_cls._json_to_kwargs(serialized))

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict."""
        # todo convert slices to json once implemented
        if self.result_type is None:
            return {"result_type": None}
        return {"result_type": self.result_type.value}


ResultInfoType = TypeVar("ResultInfoType", bound=ResultInfo)


class ODPSTableResultInfo(ResultInfo):
    """Result materialized as an ODPS table, optionally partitioned."""

    _result_type = ResultType.ODPS_TABLE

    full_table_name: str = StringField("full_table_name", default=None)
    partition_specs: Optional[List[str]] = ListField(
        "partition_specs", FieldTypes.string, default=None
    )
    table_meta: Optional["DataFrameTableMeta"] = ReferenceField(
        "table_meta", default=None
    )

    def __init__(self, result_type: ResultType = None, **kw):
        """Create the info, defaulting ``result_type`` to ODPS_TABLE."""
        super().__init__(result_type=result_type or ResultType.ODPS_TABLE, **kw)

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict, omitting unset optionals."""
        json_dict = super().to_json()
        json_dict["full_table_name"] = self.full_table_name
        if self.partition_specs:
            json_dict["partition_specs"] = self.partition_specs
        if self.table_meta:
            json_dict["table_meta"] = self.table_meta.to_json()
        return json_dict

    @classmethod
    def _json_to_kwargs(cls, serialized: dict) -> dict:
        """Translate JSON, rebuilding the nested table meta when present."""
        kw = super()._json_to_kwargs(serialized)
        if "table_meta" in kw:
            kw["table_meta"] = DataFrameTableMeta.from_json(kw["table_meta"])
        return kw


class ODPSVolumeResultInfo(ResultInfo):
    """Result materialized in an ODPS volume under a given path."""

    _result_type = ResultType.ODPS_VOLUME

    volume_name: str = StringField("volume_name", default=None)
    volume_path: str = StringField("volume_path", default=None)

    def __init__(self, result_type: ResultType = None, **kw):
        """Create the info, defaulting ``result_type`` to ODPS_VOLUME."""
        super().__init__(result_type=result_type or ResultType.ODPS_VOLUME, **kw)

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict."""
        json_dict = super().to_json()
        json_dict.update(
            volume_name=self.volume_name,
            volume_path=self.volume_path,
        )
        return json_dict


class ErrorSource(enum.Enum):
    """Origin of a raw error payload; only Python errors are defined."""

    PYTHON = 0


class ErrorInfo(JsonSerializable):
    """Carries error details across the client/server boundary.

    Besides plain-text messages and tracebacks, the raw exception object is
    transported as base64-encoded pickle buffers so it can be re-raised on
    the receiving side; when that is not possible, a ``RemoteException``
    built from the textual data is raised instead.
    """

    error_messages: Optional[List[str]] = ListField("error_messages", FieldTypes.string)
    error_tracebacks: Optional[List[List[str]]] = ListField(
        "error_tracebacks", FieldTypes.list
    )
    raw_error_source: ErrorSource = EnumField(
        "raw_error_source", ErrorSource, FieldTypes.int8, default=None
    )
    raw_error_data: Optional[Exception] = AnyField("raw_error_data", default=None)

    @classmethod
    def from_exception(cls, exc: Exception):
        """Build an ErrorInfo from a live exception instance."""
        remote_exc = RemoteException.from_exception(exc)
        messages, tracebacks = remote_exc.messages, remote_exc.tracebacks
        return cls(messages, tracebacks, ErrorSource.PYTHON, exc)

    def reraise(self):
        """Re-raise the original exception when available; otherwise raise
        a RemoteException built from the recorded messages/tracebacks."""
        if (
            self.raw_error_source == ErrorSource.PYTHON
            and self.raw_error_data is not None
        ):
            raise self.raw_error_data
        raise RemoteException(self.error_messages, self.error_tracebacks, [])

    @classmethod
    def from_json(cls, serialized: dict) -> "ErrorInfo":
        """Rebuild an ErrorInfo from the dict produced by ``to_json``."""
        kw = serialized.copy()
        if kw.get("raw_error_source") is not None:
            kw["raw_error_source"] = ErrorSource(serialized["raw_error_source"])
        else:
            kw["raw_error_source"] = None

        if kw.get("raw_error_data"):
            bufs = [base64.b64decode(s) for s in kw["raw_error_data"]]
            try:
                kw["raw_error_data"] = pickle.loads(bufs[0], buffers=bufs[1:])
            # was a bare except; stay broad for arbitrary unpickling errors
            # but let KeyboardInterrupt / SystemExit propagate
            except Exception:
                # both error source and data shall be None to make sure
                # RemoteException is raised.
                kw["raw_error_source"] = kw["raw_error_data"] = None
        return cls(**kw)

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict, pickling the raw exception."""
        ret = {
            "error_messages": self.error_messages,
            "error_tracebacks": self.error_tracebacks,
            # raw_error_source may legitimately be None (its default, and
            # the state from_json produces after a failed unpickle); guard
            # before taking .value to avoid an AttributeError here.
            "raw_error_source": (
                self.raw_error_source.value
                if self.raw_error_source is not None
                else None
            ),
        }
        err_data_bufs = None
        if isinstance(self.raw_error_data, (PickleContainer, RemoteException)):
            err_data_bufs = self.raw_error_data.get_buffers()
        elif isinstance(self.raw_error_data, BaseException):
            try:
                err_data_bufs = pickle_buffers(self.raw_error_data)
            # was a bare except; keep broad for unpicklable exceptions only
            except Exception:
                # drop the raw payload so the receiver falls back to
                # RemoteException built from the textual data
                err_data_bufs = None
                ret["raw_error_source"] = None

        if err_data_bufs:
            ret["raw_error_data"] = [
                base64.b64encode(s).decode() for s in err_data_bufs
            ]
        return ret


class DagInfo(JsonSerializable):
    """Status snapshot of a DAG execution, including per-tileable results."""

    session_id: str = StringField("session_id", default=None)
    dag_id: str = StringField("dag_id", default=None)
    status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
    progress: float = Float64Field("progress", default=None)
    tileable_to_result_infos: Dict[str, ResultInfo] = DictField(
        "tileable_to_result_infos",
        FieldTypes.string,
        FieldTypes.reference,
        default_factory=dict,
    )
    error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)
    start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
    end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
    subdag_infos: Dict[str, "SubDagInfo"] = DictField(
        "subdag_infos",
        key_type=FieldTypes.string,
        value_type=FieldTypes.reference,
        default_factory=dict,
    )

    @classmethod
    def from_json(cls, serialized: dict) -> Optional["DagInfo"]:
        """Rebuild a DagInfo from ``to_json`` output; None passes through."""
        if serialized is None:
            return None
        kw = dict(serialized)
        kw["status"] = DagStatus(kw["status"])
        result_infos = kw.get("tileable_to_result_infos")
        if result_infos:
            kw["tileable_to_result_infos"] = {
                key: ResultInfo.from_json(value)
                for key, value in result_infos.items()
            }
        error_info = kw.get("error_info")
        if error_info:
            kw["error_info"] = ErrorInfo.from_json(error_info)
        subdags = kw.get("subdag_infos")
        if subdags:
            kw["subdag_infos"] = {
                key: SubDagInfo.from_json(value) for key, value in subdags.items()
            }
        return DagInfo(**kw)

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict, omitting unset entries."""
        scalars = {
            "session_id": self.session_id,
            "dag_id": self.dag_id,
            "status": self.status.value,
            "progress": self.progress,
            "start_timestamp": self.start_timestamp,
            "end_timestamp": self.end_timestamp,
        }
        # drop scalar entries that were never set
        ret = {key: value for key, value in scalars.items() if value is not None}
        if self.tileable_to_result_infos:
            ret["tileable_to_result_infos"] = {
                key: info.to_json()
                for key, info in self.tileable_to_result_infos.items()
            }
        if self.error_info:
            ret["error_info"] = self.error_info.to_json()
        if self.subdag_infos:
            ret["subdag_infos"] = {
                key: info.to_json() for key, info in self.subdag_infos.items()
            }
        return ret


class CreateSessionRequest(Serializable):
    """Request body for creating a new session with optional settings."""

    settings: Dict[str, Any] = DictField("settings", default=None)


class SessionInfo(JsonSerializable):
    """Snapshot of a session: settings, timestamps and DAG states."""

    session_id: str = StringField("session_id")
    settings: Dict[str, Any] = DictField(
        "settings", key_type=FieldTypes.string, default=None
    )
    start_timestamp: float = Float64Field("start_timestamp", default=None)
    idle_timestamp: float = Float64Field("idle_timestamp", default=None)
    dag_infos: Dict[str, Optional[DagInfo]] = DictField(
        "dag_infos",
        key_type=FieldTypes.string,
        value_type=FieldTypes.reference,
        default=None,
    )
    error_info: Optional[ErrorInfo] = ReferenceField("error_info", default=None)

    @classmethod
    def from_json(cls, serialized: dict) -> Optional["SessionInfo"]:
        """Rebuild a SessionInfo from its JSON form; None passes through."""
        if serialized is None:
            return None
        kw = dict(serialized)
        dag_infos = kw.get("dag_infos")
        if dag_infos:
            kw["dag_infos"] = {
                dag_id: DagInfo.from_json(info)
                for dag_id, info in dag_infos.items()
            }
        error_info = kw.get("error_info")
        if error_info:
            kw["error_info"] = ErrorInfo.from_json(error_info)
        return SessionInfo(**kw)

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict."""
        ret = {
            "session_id": self.session_id,
            "settings": self.settings,
            "start_timestamp": self.start_timestamp,
            "idle_timestamp": self.idle_timestamp,
        }
        if self.dag_infos:
            # DAG entries may map to None (e.g. not yet materialized)
            ret["dag_infos"] = {
                dag_id: info.to_json() if info is not None else None
                for dag_id, info in self.dag_infos.items()
            }
        if self.error_info:
            ret["error_info"] = self.error_info.to_json()
        return ret


class ExecuteDagRequest(Serializable):
    """Request body for submitting a tileable graph for execution."""

    session_id: str = StringField("session_id")
    # graph is converted to / from SerializableGraph on the wire
    dag: TileableGraph = ReferenceField(
        "dag",
        on_serialize=SerializableGraph.from_graph,
        on_deserialize=lambda x: x.to_graph(),
        default=None,
    )
    # result infos for inputs already managed by the service — presumably
    # keyed by tileable key; verify against callers
    managed_input_infos: Dict[str, ResultInfo] = DictField(
        "managed_input_infos",
        key_type=FieldTypes.string,
        value_type=FieldTypes.reference,
        default=None,
    )
    # settings to apply on top of the session's existing ones
    new_settings: Dict[str, Any] = DictField(
        "new_settings",
        key_type=FieldTypes.string,
        default=None,
    )


class SubDagSubmitInstanceInfo(JsonSerializable):
    """Record of one service-instance submission made for a subdag."""

    submit_reason: str = StringField("submit_reason")
    instance_id: str = StringField("instance_id")
    subquery_id: Optional[int] = Int32Field("subquery_id", default=None)

    @classmethod
    def from_json(cls, serialized: dict) -> "SubDagSubmitInstanceInfo":
        """Rebuild an instance info from its JSON form."""
        return SubDagSubmitInstanceInfo(**serialized)

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict."""
        return {
            "submit_reason": self.submit_reason,
            "instance_id": self.instance_id,
            "subquery_id": self.subquery_id,
        }


class SubDagInfo(JsonSerializable):
    """Status snapshot of one subdag within a larger DAG execution."""

    subdag_id: str = StringField("subdag_id")
    status: DagStatus = EnumField("status", DagStatus, FieldTypes.int8, default=None)
    progress: float = Float64Field("progress", default=None)
    error_info: Optional[ErrorInfo] = ReferenceField(
        "error_info", reference_type=ErrorInfo, default=None
    )
    tileable_to_result_infos: Dict[str, ResultInfo] = DictField(
        "tileable_to_result_infos",
        FieldTypes.string,
        FieldTypes.reference,
        default_factory=dict,
    )
    start_timestamp: Optional[float] = Float64Field("start_timestamp", default=None)
    end_timestamp: Optional[float] = Float64Field("end_timestamp", default=None)
    submit_instances: List[SubDagSubmitInstanceInfo] = ListField(
        "submit_instances",
        FieldTypes.reference,
        default_factory=list,
    )

    @classmethod
    def from_json(cls, serialized: dict) -> "SubDagInfo":
        """Rebuild a SubDagInfo from the dict produced by ``to_json``."""
        kw = dict(serialized)
        kw["status"] = DagStatus(kw["status"])
        result_infos = kw.get("tileable_to_result_infos")
        if result_infos:
            kw["tileable_to_result_infos"] = {
                key: ResultInfo.from_json(value)
                for key, value in result_infos.items()
            }
        error_info = kw.get("error_info")
        if error_info:
            kw["error_info"] = ErrorInfo.from_json(error_info)
        submits = kw.get("submit_instances")
        if submits:
            kw["submit_instances"] = [
                SubDagSubmitInstanceInfo.from_json(item) for item in submits
            ]
        return SubDagInfo(**kw)

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict, omitting empty collections."""
        ret = {
            "subdag_id": self.subdag_id,
            "status": self.status.value,
            "progress": self.progress,
            "start_timestamp": self.start_timestamp,
            "end_timestamp": self.end_timestamp,
        }
        if self.error_info:
            ret["error_info"] = self.error_info.to_json()
        if self.tileable_to_result_infos:
            ret["tileable_to_result_infos"] = {
                key: info.to_json()
                for key, info in self.tileable_to_result_infos.items()
            }
        if self.submit_instances:
            ret["submit_instances"] = [
                instance.to_json() for instance in self.submit_instances
            ]
        return ret


class ExecuteSubDagRequest(Serializable):
    """Request body for executing a single subdag."""

    subdag_id: str = StringField("subdag_id")
    # graph is converted to / from SerializableGraph on the wire
    dag: TileableGraph = ReferenceField(
        "dag",
        on_serialize=SerializableGraph.from_graph,
        on_deserialize=lambda x: x.to_graph(),
        default=None,
    )
    settings: Dict[str, Any] = DictField("settings", default=None)


class DecrefRequest(Serializable):
    """Request to decrease reference counts for the given keys —
    presumably tileable data keys; verify against callers."""

    keys: List[str] = ListField("keys", FieldTypes.string, default=None)


class DataFrameTableMeta(JsonSerializable):
    """Mapping between a pandas object and its ODPS table representation.

    Stores table-side column names alongside pandas-side dtypes and level
    names so a stored table can be restored into an equivalent pandas
    object.
    """

    # caches for the two lazy properties below
    __slots__ = "_pd_column_names", "_pd_index_level_names"

    table_name: Optional[str] = StringField("table_name", default=None)
    # output type of the original object (dataframe, series, ...)
    type: OutputType = EnumField("type", OutputType, FieldTypes.int8, default=None)
    table_column_names: List[str] = ListField(
        "table_column_names", FieldTypes.string, default=None
    )
    table_index_column_names: List[str] = ListField(
        "table_index_column_names", FieldTypes.string, default=None
    )
    # pandas dtypes indexed by column label
    pd_column_dtypes: pd.Series = SeriesField("pd_column_dtypes", default=None)
    pd_column_level_names: List[Any] = ListField("pd_column_level_names", default=None)
    # pandas index level dtypes indexed by level name
    pd_index_dtypes: pd.Series = SeriesField("pd_index_dtypes", default=None)

    @property
    def pd_column_names(self) -> List[Any]:
        """Pandas column labels, computed once from ``pd_column_dtypes``."""
        try:
            return self._pd_column_names
        except AttributeError:
            self._pd_column_names = self.pd_column_dtypes.index.tolist()
            return self._pd_column_names

    @property
    def pd_index_level_names(self) -> List[Any]:
        """Pandas index level names, computed once from ``pd_index_dtypes``."""
        try:
            return self._pd_index_level_names
        except AttributeError:
            self._pd_index_level_names = self.pd_index_dtypes.index.tolist()
            return self._pd_index_level_names

    def __eq__(self, other: "DataFrameTableMeta") -> bool:
        """Field-wise equality against another meta of the same type."""
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable — confirm this is intended.
        if not isinstance(other, type(self)):
            return False
        for k in self._FIELDS:
            v = getattr(self, k, None)
            is_same = v == getattr(other, k, None)
            # pandas objects compare element-wise; collapse to a single bool
            if callable(getattr(is_same, "all", None)):
                is_same = is_same.all()
            if not is_same:
                return False
        return True

    def to_json(self) -> dict:
        """Convert to a JSON-compatible dict; pandas parts are pickled."""

        def b64_pickle(obj) -> str:
            # pickle + base64 keeps arbitrary dtype objects JSON-safe
            return base64.b64encode(pickle.dumps(obj)).decode()

        return {
            "table_name": self.table_name,
            "type": self.type.value,
            "table_column_names": self.table_column_names,
            "table_index_column_names": self.table_index_column_names,
            "pd_column_dtypes": b64_pickle(self.pd_column_dtypes),
            "pd_column_level_names": b64_pickle(self.pd_column_level_names),
            "pd_index_dtypes": b64_pickle(self.pd_index_dtypes),
        }

    @classmethod
    def from_json(cls, serialized: dict) -> "DataFrameTableMeta":
        """Rebuild a meta object from the dict produced by ``to_json``."""

        def b64_unpickle(s):
            return pickle.loads(base64.b64decode(s))

        # copy before rewriting entries: the previous implementation
        # updated the caller's dict in place, unlike every other
        # from_json in this module
        kw = serialized.copy()
        kw.update(
            {
                "type": OutputType(kw["type"]),
                "pd_column_dtypes": b64_unpickle(kw["pd_column_dtypes"]),
                "pd_column_level_names": b64_unpickle(kw["pd_column_level_names"]),
                "pd_index_dtypes": b64_unpickle(kw["pd_index_dtypes"]),
            }
        )
        return cls(**kw)
