# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import asyncio
import copy
import logging
import time
import weakref
from numbers import Integral
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from urllib.parse import urlparse

import numpy as np
import pandas as pd
from odps import ODPS
from odps import options as odps_options
from odps.console import in_ipython_frontend

from maxframe.config import options
from maxframe.core import Entity, TileableGraph, build_fetch, enter_mode
from maxframe.core.operator import Fetch
from maxframe.dataframe import read_odps_table
from maxframe.dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
from maxframe.dataframe.datasource import PandasDataSourceOperator
from maxframe.dataframe.datasource.read_odps_table import DataFrameReadODPSTable
from maxframe.errors import (
    MaxFrameError,
    NoTaskServerResponseError,
    SessionAlreadyClosedError,
)
from maxframe.io.objects import get_object_io_handler
from maxframe.io.odpsio import (
    ODPSTableIO,
    ODPSVolumeWriter,
    pandas_to_arrow,
    pandas_to_odps_schema,
)
from maxframe.protocol import (
    DagInfo,
    DagStatus,
    ODPSTableResultInfo,
    ODPSVolumeResultInfo,
    ResultInfo,
    SessionInfo,
)
from maxframe.session import (
    AbstractSession,
    ExecutionInfo,
    IsolatedAsyncSession,
    Profiling,
    Progress,
)
from maxframe.tensor.datasource import ArrayDataSource
from maxframe.typing_ import TileableType
from maxframe.utils import (
    ToThreadMixin,
    build_session_volume_name,
    build_temp_table_name,
    get_default_table_properties,
    str_to_bool,
    sync_pyodps_options,
)

from ..clients.framedriver import FrameDriverClient
from ..fetcher import get_fetcher_cls
from .consts import RESTFUL_SESSION_INSECURE_SCHEME, RESTFUL_SESSION_SECURE_SCHEME
from .graph import gen_submit_tileable_graph

logger = logging.getLogger(__name__)


class MaxFrameServiceCaller(metaclass=abc.ABCMeta):
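    """Abstraction of calls issued by a session to the MaxFrame service."""
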
    def get_settings_to_upload(self) -> Dict[str, Any]:
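        """Merge PyODPS and MaxFrame options into the settings dict to upload.

        Resolves conflicting quota specifications, table lifecycles and
        schema-related flags before the settings are sent to the service.
        """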
        sql_settings = (odps_options.sql.settings or {}).copy()
        sql_settings.update(options.sql.settings or {})
        quota_name = options.session.quota_name or getattr(
            odps_options, "quota_name", None
        )
        quota_settings = {
            sql_settings.get("odps.task.wlm.quota", None),
            options.spe.task.settings.get("odps.task.wlm.quota", None),
            options.pythonpack.task.settings.get("odps.task.wlm.quota", None),
            quota_name,
        }.difference([None])
        if len(quota_settings) >= 2:
            raise ValueError(
                "Quota settings are conflicting: %s" % ", ".join(sorted(quota_settings))
            )
        elif len(quota_settings) == 1:
            quota_name = quota_settings.pop()
        lifecycle = options.session.table_lifecycle or odps_options.lifecycle
        temp_lifecycle = (
            options.session.temp_table_lifecycle or odps_options.temp_lifecycle
        )

        enable_schema = options.session.enable_schema
        default_schema = options.session.default_schema
        if hasattr(self, "_odps_entry"):
            default_schema = default_schema or self._odps_entry.schema

        # flags in SQL settings take precedence over local schema options
        if sql_settings.get("odps.default.schema"):
            default_schema = sql_settings["odps.default.schema"]
        if str_to_bool(
            sql_settings.get("odps.namespace.schema") or "false"
        ) or str_to_bool(
            sql_settings.get("odps.sql.allow.namespace.schema") or "false"
        ):
            enable_schema = True

        mf_settings = dict(options.to_dict(remote_only=True).items())
        mf_settings["sql.settings"] = sql_settings
        mf_settings["session.table_lifecycle"] = lifecycle
        mf_settings["session.temp_table_lifecycle"] = temp_lifecycle
        mf_settings["session.quota_name"] = quota_name
        if enable_schema is not None:
            mf_settings["session.enable_schema"] = enable_schema
        if options.session.enable_high_availability is None:
            mf_settings["session.enable_high_availability"] = not in_ipython_frontend()
        mf_settings["session.default_schema"] = default_schema or "default"
        return mf_settings

    @abc.abstractmethod
    def create_session(self) -> SessionInfo:
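        """Create a remote MaxFrame session and return its information."""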
        raise NotImplementedError

    @abc.abstractmethod
    def delete_session(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def submit_dag(
        self,
        dag: TileableGraph,
        managed_input_infos: Dict[str, ResultInfo],
        new_settings: Optional[Dict[str, Any]] = None,
    ) -> DagInfo:
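        """Submit a tileable graph together with infos of managed inputs and
        updated settings, returning information of the created DAG."""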
        raise NotImplementedError

    @abc.abstractmethod
    def get_dag_info(self, dag_id: str) -> DagInfo:
        raise NotImplementedError

    @abc.abstractmethod
    def cancel_dag(self, dag_id: str) -> DagInfo:
        raise NotImplementedError

    @abc.abstractmethod
    def decref(self, tileable_keys: List[str]) -> None:
        raise NotImplementedError

    def get_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
        return None


class MaxFrameSession(ToThreadMixin, IsolatedAsyncSession):
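    """Session implementation that submits tileable graphs to the MaxFrame
    service and fetches results from MaxCompute tables or volumes."""
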
    _odps_entry: Optional[ODPS]
    _tileable_to_infos: Mapping[TileableType, ResultInfo]

    @classmethod
    async def init(
        cls,
        address: str,
        session_id: Optional[str] = None,
        backend: Optional[str] = None,
        odps_entry: Optional[ODPS] = None,
        timeout: Optional[float] = None,
        **kwargs,
    ) -> "AbstractSession":
        session_obj = cls(
            address, session_id, odps_entry=odps_entry, timeout=timeout, **kwargs
        )
        await session_obj._init(address)
        return session_obj

    def __init__(
        self,
        address: str,
        session_id: str,
        odps_entry: Optional[ODPS] = None,
        timeout: Optional[float] = None,
        **kwargs,
    ):
        super().__init__(address, session_id)
        self.timeout = timeout
        self._odps_entry = odps_entry or ODPS.from_global() or ODPS.from_environments()
        self._tileable_to_infos = weakref.WeakKeyDictionary()

        self._caller = self._create_caller(odps_entry, address, **kwargs)
        self._last_settings = None
        self._pull_interval = 1 if in_ipython_frontend() else 3
        self._replace_internal_host = kwargs.get("replace_internal_host", True)

    @classmethod
    def _create_caller(
        cls, odps_entry: ODPS, address: str, **kwargs
    ) -> MaxFrameServiceCaller:
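        """Create the caller that talks to a concrete service backend."""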
        raise NotImplementedError

    async def _init(self, _address: str):
        session_info = await self.ensure_async_call(self._caller.create_session)
        self._last_settings = copy.deepcopy(self._caller.get_settings_to_upload())
        self._session_id = session_info.session_id
        await self._show_logview_address()

    def _upload_and_get_table_read_tileable(
        self, t: TileableType
    ) -> Optional[TileableType]:
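        """Upload data of a local pandas source into a temp MaxCompute table
        and return a tileable reading that table as its replacement."""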
        table_schema, table_meta = pandas_to_odps_schema(t, unknown_as_string=True)
        table_name = build_temp_table_name(self.session_id, t.key)
        if self._odps_entry.exist_table(table_name):
            self._odps_entry.delete_table(table_name, hints=options.sql.settings)
        table_obj = self._odps_entry.create_table(
            table_name,
            table_schema,
            lifecycle=options.session.temp_table_lifecycle,
            hints=options.sql.settings,
            if_not_exists=True,
            table_properties=options.session.temp_table_properties
            or get_default_table_properties(),
        )

        data = t.op.get_data()
        batch_size = options.session.upload_batch_size

        if len(data):
            table_client = ODPSTableIO(self._odps_entry)
            with table_client.open_writer(table_obj.full_table_name) as writer:
                for batch_start in range(0, len(data), batch_size):
                    if isinstance(data, pd.Index):
                        batch = data[batch_start : batch_start + batch_size]
                    else:
                        batch = data.iloc[batch_start : batch_start + batch_size]
                    arrow_batch, _ = pandas_to_arrow(batch)
                    writer.write(arrow_batch)

        read_tileable = read_odps_table(
            table_obj.full_table_name,
            columns=table_meta.table_column_names,
            index_col=table_meta.table_index_column_names,
            output_type=table_meta.type,
        )
        if isinstance(read_tileable, DATAFRAME_TYPE):
            if list(read_tileable.dtypes.index) != list(t.dtypes.index):
                read_tileable.columns = list(t.dtypes.index)
        elif isinstance(read_tileable, SERIES_TYPE):
            if read_tileable.name != t.name:
                read_tileable.name = t.name
        else:  # INDEX_TYPE
            if list(read_tileable.names) != list(t.names):
                read_tileable.rename(t.names, inplace=True)
        read_tileable._key = t.key
        read_tileable.params = t.params
        return read_tileable.data

    def _upload_and_get_vol_read_tileable(
        self, t: TileableType
    ) -> Optional[TileableType]:
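        """Upload data of a local object into the session volume and return
        a fetch tileable as its replacement."""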
        vol_name = build_session_volume_name(self.session_id)
        writer = ODPSVolumeWriter(
            self._odps_entry,
            vol_name,
            t.key,
            replace_internal_host=self._replace_internal_host,
        )
        io_handler = get_object_io_handler(t)
        io_handler().write_object(writer, t, t.op.data)
        return build_fetch(t).data

    def _upload_and_get_read_tileable(self, t: TileableType) -> Optional[TileableType]:
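        """Return a remote replacement tileable for a local data source,
        or None if the tileable needs no replacement."""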
        if (
            not isinstance(t.op, (ArrayDataSource, PandasDataSourceOperator))
            or t.op.get_data() is None
            or t.inputs
        ):
            return None
        with sync_pyodps_options():
            if isinstance(t.op, PandasDataSourceOperator):
                return self._upload_and_get_table_read_tileable(t)
            else:
                return self._upload_and_get_vol_read_tileable(t)

    @enter_mode(kernel=True, build=True)
    def _scan_and_replace_local_sources(
        self, graph: TileableGraph
    ) -> Dict[TileableType, TileableType]:
        """Replaces Pandas data sources with temp table sources in the graph"""
        replacements = dict()
        for t in graph:
            replaced = self._upload_and_get_read_tileable(t)
            if replaced is None:
                continue
            replacements[t] = replaced

        for src, replaced in replacements.items():
            successors = list(graph.successors(src))
            graph.remove_node(src)
            graph.add_node(replaced)
            for pred in replaced.inputs or ():
                graph.add_node(pred)
                graph.add_edge(pred, replaced)

            for succ in successors:
                graph.add_edge(replaced, succ)
                succ.op._set_inputs([replacements.get(t, t) for t in succ.inputs])

        graph.results = [replacements.get(t, t) for t in graph.results]
        return replacements

    @enter_mode(kernel=True, build=True)
    def _get_input_infos(self, tileables: List[TileableType]) -> Dict[str, ResultInfo]:
        """Generate ResultInfo structs from generated temp tables"""
        vol_name = build_session_volume_name(self.session_id)

        infos = dict()
        for t in tileables:
            key = t.key
            if isinstance(t.op, DataFrameReadODPSTable):
                infos[key] = ODPSTableResultInfo(full_table_name=t.op.table_name)
            else:
                if isinstance(t.op, Fetch):
                    infos[key] = ODPSVolumeResultInfo(
                        volume_name=vol_name, volume_path=t.key
                    )
                elif t.inputs and isinstance(t.inputs[0].op, DataFrameReadODPSTable):
                    t = t.inputs[0]
                    infos[key] = ODPSTableResultInfo(full_table_name=t.op.table_name)
        return infos

    def _get_diff_settings(self) -> Dict[str, Any]:
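        """Compute settings changed since the last upload so that only the
        difference is sent with the next DAG submission."""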
        new_settings = self._caller.get_settings_to_upload()
        if not self._last_settings:  # pragma: no cover
            self._last_settings = copy.deepcopy(new_settings)
            return new_settings

        if self._last_settings.get("session.quota_name", None) != new_settings.get(
            "session.quota_name", None
        ):
            raise ValueError("Quota name cannot be changed after sessions are created")

        update = dict()
        for k in new_settings.keys():
            old_item = self._last_settings.get(k)
            new_item = new_settings.get(k)
            try:
                if old_item != new_item:
                    update[k] = new_item
            except:  # noqa: E722  # nosec  # pylint: disable=bare-except
                update[k] = new_item
        self._last_settings = copy.deepcopy(new_settings)
        return update

    async def execute(self, *tileables, **kwargs) -> ExecutionInfo:
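        """Build a tileable graph for the tileables, replace local data
        sources, submit the graph as a DAG and watch it in a background task."""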
        tileables = [
            tileable.data if isinstance(tileable, Entity) else tileable
            for tileable in tileables
        ]
        tileable_to_copied = dict()
        tileable_graph, to_execute_tileables = gen_submit_tileable_graph(
            self, tileables, tileable_to_copied
        )
        source_replacements = self._scan_and_replace_local_sources(tileable_graph)

        # uploaded data sources need to be managed by the ref-counting mechanism;
        # as nodes in tileable_graph are copies, sessions are attached to the
        # original nodes
        copied_to_tileable = {v: k for k, v in tileable_to_copied.items()}
        for replaced_src in source_replacements.keys():
            copied_to_tileable[replaced_src]._attach_session(self)

        replaced_infos = self._get_input_infos(list(source_replacements.values()))
        dag_info = await self.ensure_async_call(
            self._caller.submit_dag,
            tileable_graph,
            replaced_infos,
            self._get_diff_settings(),
        )

        await self._show_logview_address(dag_info.dag_id)

        progress = Progress()
        profiling = Profiling()
        aio_task = asyncio.create_task(
            self._run_in_background(dag_info, to_execute_tileables, progress)
        )
        return ExecutionInfo(
            aio_task,
            progress,
            profiling,
            asyncio.get_running_loop(),
            to_execute_tileables,
        )

    async def _run_in_background(
        self, dag_info: DagInfo, tileables: List, progress: Progress
    ):
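        """Poll the status of a submitted DAG until it terminates and then
        update result metas of the executed tileables."""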
        start_time = time.time()
        session_id = dag_info.session_id
        dag_id = dag_info.dag_id
        server_no_response_time = None
        with enter_mode(build=True, kernel=True):
            key_to_tileables = {t.key: t for t in tileables}
            timeout_val = 0.1
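            # poll DAG status with exponential backoff: the interval starts at
            # 0.1 seconds, doubles every round and is capped by the pull
            # interval as well as the remaining session timeout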
            try:
                while True:
                    elapsed_time = time.time() - start_time
                    next_timeout_val = min(timeout_val * 2, self._pull_interval)
                    timeout_val = (
                        min(self.timeout - elapsed_time, next_timeout_val)
                        if self.timeout
                        else next_timeout_val
                    )
                    if timeout_val <= 0:
                        raise TimeoutError("Running DAG timed out")

                    try:
                        dag_info: DagInfo = await self.ensure_async_call(
                            self._caller.get_dag_info, dag_id
                        )
                        server_no_response_time = None
                    except (NoTaskServerResponseError, SessionAlreadyClosedError) as ex:
                        # when SessionAlreadyClosedError is received after a
                        #  NoTaskServerResponseError, the task server may have been
                        #  restarted and the error can be transient. Otherwise the
                        #  error should be raised immediately.
                        if (
                            isinstance(ex, SessionAlreadyClosedError)
                            and not server_no_response_time
                        ):
                            raise
                        server_no_response_time = server_no_response_time or time.time()
                        if (
                            time.time() - server_no_response_time
                            > options.client.task_restart_timeout
                        ):
                            raise MaxFrameError(
                                "Failed to get valid response from service. "
                                f"Session {self._session_id}."
                            ) from None
                        await asyncio.sleep(timeout_val)
                        continue

                    if dag_info is None:
                        raise SystemError(
                            f"Cannot find DAG with ID {dag_id} in session {session_id}"
                        )
                    progress.value = dag_info.progress
                    if dag_info.status != DagStatus.RUNNING:
                        break
                    await asyncio.sleep(timeout_val)
            except asyncio.CancelledError:
                dag_info = await self.ensure_async_call(self._caller.cancel_dag, dag_id)
                if dag_info.status != DagStatus.CANCELLED:  # pragma: no cover
                    raise
            finally:
                if dag_info.status == DagStatus.SUCCEEDED:
                    progress.value = 1.0
                elif dag_info.status == DagStatus.FAILED:
                    dag_info.error_info.reraise()

            if dag_info.status in (DagStatus.RUNNING, DagStatus.CANCELLED):
                return

            for key, result_info in dag_info.tileable_to_result_infos.items():
                t = key_to_tileables[key]
                fetcher = get_fetcher_cls(result_info.result_type)(self._odps_entry)
                await fetcher.update_tileable_meta(t, result_info)
                self._tileable_to_infos[t] = result_info

    def _get_data_tileable_and_indexes(
        self, tileable: TileableType
    ) -> Tuple[TileableType, List[Union[slice, Integral]]]:
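        """Resolve a tileable to an executed ancestor carrying result info,
        collecting integer and slice indexes applied along the way."""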
        from maxframe.dataframe.indexing.iloc import (
            DataFrameIlocGetItem,
            SeriesIlocGetItem,
        )
        from maxframe.tensor.indexing import TensorIndex

        slice_op_types = TensorIndex, DataFrameIlocGetItem, SeriesIlocGetItem

        if isinstance(tileable, Entity):
            tileable = tileable.data

        indexes = None
        while tileable not in self._tileable_to_infos:
            # if the tileable's op is a slicing op, walk up and check its input
            if isinstance(tileable.op, slice_op_types):
                indexes = tileable.op.indexes
                tileable = tileable.inputs[0]
                if not all(isinstance(index, (slice, Integral)) for index in indexes):
                    raise ValueError(
                        "Only integer and slice index values are supported when fetching data"
                    )
            else:
                raise ValueError(f"Cannot fetch unexecuted tileable: {tileable!r}")

        return tileable, indexes

    async def fetch(self, *tileables, **kwargs) -> list:
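        """Fetch results of executed tileables to the client, applying
        recorded slice indexes if any."""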
        results = []
        tileables = [
            tileable.data if isinstance(tileable, Entity) else tileable
            for tileable in tileables
        ]
        with enter_mode(build=True):
            for tileable in tileables:
                data_tileable, indexes = self._get_data_tileable_and_indexes(tileable)
                info = self._tileable_to_infos[data_tileable]
                fetcher = get_fetcher_cls(info.result_type)(self._odps_entry)
                results.append(await fetcher.fetch(data_tileable, info, indexes))
        return results

    async def decref(self, *tileable_keys):
        return await self.ensure_async_call(self._caller.decref, list(tileable_keys))

    async def destroy(self):
        await self.ensure_async_call(self._caller.delete_session)
        await super().destroy()

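    # The methods below are required by the AbstractSession interface but are
    # either no-ops or unsupported for MaxFrame sessions.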
    async def _get_ref_counts(self) -> Dict[str, int]:
        pass

    async def fetch_tileable_op_logs(
        self,
        tileable_op_key: str,
        offsets: Union[Dict[str, List[int]], str, int],
        sizes: Union[Dict[str, List[int]], str, int],
    ) -> Dict:
        pass

    async def get_total_n_cpu(self):
        pass

    async def get_cluster_versions(self) -> List[str]:
        raise NotImplementedError

    async def get_web_endpoint(self) -> Optional[str]:
        raise NotImplementedError

    async def create_remote_object(
        self, session_id: str, name: str, object_cls, *args, **kwargs
    ):
        raise NotImplementedError

    async def get_remote_object(self, session_id: str, name: str):
        raise NotImplementedError

    async def destroy_remote_object(self, session_id: str, name: str):
        raise NotImplementedError

    async def create_mutable_tensor(
        self,
        shape: tuple,
        dtype: Union[np.dtype, str],
        name: Optional[str] = None,
        default_value: Union[int, float] = 0,
        chunk_size: Optional[Union[int, Tuple]] = None,
    ):
        raise NotImplementedError

    async def get_mutable_tensor(self, name: str):
        raise NotImplementedError

    async def get_logview_address(self, hours=None) -> Optional[str]:
        return await self.get_dag_logview_address(None, hours)

    async def get_dag_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
        return await self.ensure_async_call(
            self._caller.get_logview_address, dag_id, hours
        )

    async def _show_logview_address(self, dag_id=None, hours=None):
        identity = f"Session ID: {self._session_id}"
        if dag_id:
            identity += f", DAG ID: {dag_id}"

        logview_addr = await self.get_dag_logview_address(dag_id, hours)
        if logview_addr:
            logger.info("%s, Logview: %s", identity, logview_addr)
        else:
            logger.info("%s, Logview address does not exist", identity)


class MaxFrameRestCaller(MaxFrameServiceCaller):
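    """Service caller implemented on top of the RESTful FrameDriver client."""
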
    _client: FrameDriverClient
    _session_id: Optional[str]

    def __init__(self, odps_entry: ODPS, client: FrameDriverClient):
        self._odps_entry = odps_entry
        self._client = client
        self._session_id = None

    async def create_session(self) -> SessionInfo:
        info = await self._client.create_session(self.get_settings_to_upload())
        self._session_id = info.session_id
        return info

    async def delete_session(self) -> None:
        await self._client.delete_session(self._session_id)

    async def submit_dag(
        self,
        dag: TileableGraph,
        managed_input_infos: Optional[Dict[str, ResultInfo]] = None,
        new_settings: Optional[Dict[str, Any]] = None,
    ) -> DagInfo:
        return await self._client.submit_dag(
            self._session_id, dag, managed_input_infos, new_settings=new_settings
        )

    async def get_dag_info(self, dag_id: str) -> DagInfo:
        return await self._client.get_dag_info(self._session_id, dag_id)

    async def cancel_dag(self, dag_id: str) -> DagInfo:
        return await self._client.cancel_dag(self._session_id, dag_id)

    async def decref(self, tileable_keys: List[str]) -> None:
        return await self._client.decref(self._session_id, tileable_keys)


class MaxFrameRestSession(MaxFrameSession):
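    """MaxFrame session connecting to the service via RESTful endpoints.

    Registered session schemes in the address are translated into plain
    http / https before the underlying REST client connects.
    """
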
    schemes = [RESTFUL_SESSION_INSECURE_SCHEME, RESTFUL_SESSION_SECURE_SCHEME]

    def __init__(
        self,
        address: str,
        session_id: str,
        odps_entry: Optional[ODPS] = None,
        timeout: Optional[float] = None,
        new: bool = True,
        **kwargs,
    ):
        parsed_endpoint = urlparse(address)
        scheme = (
            "http"
            if parsed_endpoint.scheme == RESTFUL_SESSION_INSECURE_SCHEME
            else "https"
        )
        real_endpoint = address.replace(f"{parsed_endpoint.scheme}://", f"{scheme}://")

        super().__init__(
            real_endpoint, session_id, odps_entry=odps_entry, timeout=timeout, **kwargs
        )

    @classmethod
    def _create_caller(cls, odps_entry: ODPS, address: str, **kwargs):
        return MaxFrameRestCaller(odps_entry, FrameDriverClient(address))


def register_session_schemes(overwrite: bool = False):
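    """Register address schemes of MaxFrameRestSession so sessions can be
    created directly from MaxFrame service addresses."""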
    MaxFrameRestSession.register_schemes(overwrite=overwrite)
