# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import logging
import re
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from odps import ODPS
from odps.types import Column, OdpsSchema, validate_data_type
from odps.utils import split_sql_by_semicolon

from ... import opcodes
from ...config import options
from ...core import OutputType
from ...core.graph import DAG
from ...io.odpsio import odps_schema_to_pandas_dtypes
from ...serialization.serializables import (
    AnyField,
    BoolField,
    DictField,
    FieldTypes,
    Int64Field,
    ListField,
    SeriesField,
    StringField,
)
from ...utils import is_empty
from ..utils import parse_index
from .core import ColumnPruneSupportedDataSourceMixin, IncrementalIndexDatasource

logger = logging.getLogger(__name__)

_DEFAULT_ANONYMOUS_COL_PREFIX = "_anon_col_"

_EXPLAIN_DEPENDS_REGEX = re.compile(r"([^\s]+) depends on: ([^\n]+)")
_EXPLAIN_JOB_REGEX = re.compile(r"(\S+) is root job")
_EXPLAIN_TASKS_HEADER_REGEX = re.compile(r"In Job ([^:]+):")
_EXPLAIN_ROOT_TASKS_REGEX = re.compile(r"root Tasks: (.+)")
_EXPLAIN_TASK_REGEX = re.compile(r"In Task ([^:]+)")
_EXPLAIN_TASK_SCHEMA_REGEX = re.compile(
    r"In Task ([^:]+)[\S\s]+FS: output: ([^\n #]+)[\s\S]+schema:\s+([\S\s]+)$",
    re.MULTILINE,
)
_EXPLAIN_COLUMN_REGEX = re.compile(r"([^\(]+) \(([^\n]+)\)(?:| AS ([^ ]+))(?:\n|$)")
_ANONYMOUS_COL_REGEX = re.compile(r"^_c(\d+)$")

_SIMPLE_SCHEMA_COLS_REGEX = re.compile(r"SELECT (([^:]+:[^, ]+[, ]*)+)FROM")
_SIMPLE_SCHEMA_COL_REGEX = re.compile(r"([^ \.\)]+):([^ ]+)")


@dataclasses.dataclass
class DependencySector:
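    """Roots and dependency edges parsed from one sector of an EXPLAIN result."""
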
    roots: List[str]
    dependencies: List[Tuple[str, str]]

    def build_dag(self) -> DAG:
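        """Build a DAG from the recorded roots and dependency edges."""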
        dag = DAG()
        for r in self.roots:
            dag.add_node(r)
        for v_from, v_to in self.dependencies:
            dag.add_node(v_from)
            dag.add_node(v_to)
            dag.add_edge(v_from, v_to)
        return dag


@dataclasses.dataclass
class JobsSector(DependencySector):
    jobs: Dict[str, "TasksSector"] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class TasksSector(DependencySector):
    job_name: str
    tasks: Dict[str, "TaskSector"] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class ColumnSchema:
    column_name: str
    column_type: str
    column_alias: Optional[str]


@dataclasses.dataclass
class TaskSector:
    job_name: str
    task_name: str
    output_target: Optional[str]
    schema: List[ColumnSchema]


def _split_explain_string(explain_string: str) -> List[str]:
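    """Split an EXPLAIN result into blank-line-separated sectors.

    Parts that start with whitespace are treated as continuations and are
    grouped with the preceding sector.
    """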
    parts = explain_string.split("\n\n")
    final_parts = []
    grouped = []
    for part in parts:
        part = part.strip("\n")
        if grouped and not part.startswith(" "):
            final_parts.append("\n\n".join(grouped).strip())
            grouped = []
        grouped.append(part)
    if grouped:
        final_parts.append("\n\n".join(grouped).strip())
    return final_parts


def _find_all_deps(sector: str) -> List[Tuple[str, str]]:
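    """Extract (dependency, dependent) pairs from 'X depends on: A, B' lines."""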
    deps = []
    for match in _EXPLAIN_DEPENDS_REGEX.findall(sector):
        descendant = match[0]
        for r in match[1].split(","):
            deps.append((r.strip(), descendant))
    return deps


def _resolve_jobs_sector(sector: str) -> JobsSector:
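    """Parse the job-level sector listing root jobs and job dependencies."""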
    match = _EXPLAIN_JOB_REGEX.search(sector)
    roots = [r.strip() for r in match.group(1).split(",")]
    deps = _find_all_deps(sector)
    return JobsSector(roots, deps)


def _resolve_tasks_sector(sector: str) -> TasksSector:
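    """Parse an 'In Job X:' sector listing root tasks and task dependencies."""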
    match = _EXPLAIN_ROOT_TASKS_REGEX.search(sector)
    roots = [r.strip() for r in match.group(1).split(",")]

    match = _EXPLAIN_TASKS_HEADER_REGEX.search(sector)
    job_name = match.group(1)

    deps = _find_all_deps(sector)
    return TasksSector(roots, deps, job_name)


def _resolve_task_sector(job_name: str, sector: str) -> TaskSector:
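    """Parse an 'In Task X' sector into its output target and column schema."""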
    match = _EXPLAIN_TASK_REGEX.match(sector)
    task_name = match.group(1)

    match = _EXPLAIN_TASK_SCHEMA_REGEX.match(sector)
    if match is None:
        return TaskSector(job_name, task_name, None, [])

    out_target = match.group(2)
    out_schema = match.group(3)

    schemas = []
    for match in _EXPLAIN_COLUMN_REGEX.findall(out_schema):
        col_name, data_type, alias = match
        schemas.append(ColumnSchema(col_name.strip(), data_type.strip(), alias.strip()))
    return TaskSector(job_name, task_name, out_target, schemas)


def _select_task_prefix(sector: TasksSector, prefix: str) -> List[TaskSector]:
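    """Select tasks by exact name, or by the '<prefix>_' name prefix otherwise."""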
    if prefix in sector.tasks:
        return [sector.tasks[prefix]]
    return [v for k, v in sector.tasks.items() if k.startswith(prefix + "_")]


def _parse_full_explain(explain_string: str) -> OdpsSchema:
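    """Derive the output schema from a full EXPLAIN result.

    Walks the job and task dependency DAGs to locate the terminal task(s)
    writing to Screen; their column schema becomes the result schema.
    """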
    sectors = _split_explain_string(explain_string)
    jobs_sector = tasks_sector = None

    for sector in sectors:
        if _EXPLAIN_JOB_REGEX.search(sector):
            jobs_sector = _resolve_jobs_sector(sector)
        elif _EXPLAIN_TASKS_HEADER_REGEX.search(sector):
            tasks_sector = _resolve_tasks_sector(sector)
            assert jobs_sector is not None
            jobs_sector.jobs[tasks_sector.job_name] = tasks_sector
        elif _EXPLAIN_TASK_REGEX.search(sector):
            assert tasks_sector is not None
            task_sector = _resolve_task_sector(tasks_sector.job_name, sector)
            tasks_sector.tasks[task_sector.task_name] = task_sector

    job_dag = jobs_sector.build_dag()
    indep_job_names = list(job_dag.iter_indep(reverse=True))
    schema_signatures = dict()
    for job_name in indep_job_names:
        tasks_sector = jobs_sector.jobs[job_name]
        task_dag = tasks_sector.build_dag()
        indep_task_names = list(task_dag.iter_indep(reverse=True))
        for task_name in indep_task_names:
            for task_sector in _select_task_prefix(tasks_sector, task_name):
                if not task_sector.schema:  # pragma: no cover
                    raise ValueError("Cannot detect output schema")
                if task_sector.output_target != "Screen":
                    raise ValueError("The SQL statement should be an instant query")
                sig_tuples = sorted(
                    [
                        (c.column_alias or c.column_name, c.column_type)
                        for c in task_sector.schema
                    ]
                )
                schema_signatures[hash(tuple(sig_tuples))] = task_sector.schema
    if len(schema_signatures) != 1:
        raise ValueError("Only one final task is allowed in SQL statement")
    schema = list(schema_signatures.values())[0]
    cols = [
        Column(c.column_alias or c.column_name, validate_data_type(c.column_type))
        for c in schema
    ]
    return OdpsSchema(cols)


def _parse_simple_explain(explain_string: str) -> OdpsSchema:
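    """Derive the output schema from a simple (AdhocSink) EXPLAIN result."""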
    fields_match = _SIMPLE_SCHEMA_COLS_REGEX.search(explain_string)
    if not fields_match:
        raise ValueError("Cannot detect output table schema")

    fields_str = fields_match.group(1)
    cols = []
    for field, type_name in _SIMPLE_SCHEMA_COL_REGEX.findall(fields_str):
        cols.append(Column(field, validate_data_type(type_name.rstrip(","))))
    return OdpsSchema(cols)


def _parse_explained_schema(explain_string: str) -> OdpsSchema:
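    """Dispatch between the simple and the full EXPLAIN result format."""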
    if explain_string.startswith("AdhocSink"):
        return _parse_simple_explain(explain_string)
    else:
        return _parse_full_explain(explain_string)


def _build_explain_sql(sql_stmt: str, no_split: bool = False) -> str:
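    """Prepend EXPLAIN to the last statement of an SQL script."""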
    if no_split:
        return "EXPLAIN " + sql_stmt
    sql_parts = split_sql_by_semicolon(sql_stmt)
    if not sql_parts:
        raise ValueError(f"Cannot explain SQL statement {sql_stmt}")
    sql_parts[-1] = "EXPLAIN " + sql_parts[-1]
    return "\n".join(sql_parts)


class DataFrameReadODPSQuery(
    IncrementalIndexDatasource,
    ColumnPruneSupportedDataSourceMixin,
):
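    """Operator that reads the result of a MaxCompute SQL query as a DataFrame."""
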
    _op_type_ = opcodes.READ_ODPS_QUERY

    query = StringField("query")
    dtypes = SeriesField("dtypes", default=None)
    columns = AnyField("columns", default=None)
    nrows = Int64Field("nrows", default=None)
    use_arrow_dtype = BoolField("use_arrow_dtype", default=None)
    string_as_binary = BoolField("string_as_binary", default=None)
    index_columns = ListField("index_columns", FieldTypes.string, default=None)
    index_dtypes = SeriesField("index_dtypes", default=None)
    column_renames = DictField("column_renames", default=None)
    # declared so the ``no_split_sql`` keyword passed when the operator is
    # constructed in read_odps_query() below is retained on the operator
    no_split_sql = BoolField("no_split_sql", default=False)

    def get_columns(self):
        return self.columns or list(self.dtypes.index)

    def set_pruned_columns(self, columns, *, keep_order=None):  # pragma: no cover
        self.columns = columns

    def __call__(self, chunk_bytes=None, chunk_size=None):
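        """Create the output DataFrame tileable with index and dtype metadata."""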
        if is_empty(self.index_columns):
            index_value = parse_index(pd.RangeIndex(0))
        elif len(self.index_columns) == 1:
            index_value = parse_index(
                pd.Index([], name=self.index_columns[0]).astype(
                    self.index_dtypes.iloc[0]
                )
            )
        else:
            idx = pd.MultiIndex.from_frame(
                pd.DataFrame([], columns=self.index_columns).astype(self.index_dtypes)
            )
            index_value = parse_index(idx)

        if self.dtypes is not None:
            columns_value = parse_index(self.dtypes.index, store_data=True)
            shape = (np.nan, len(self.dtypes))
        else:
            columns_value = None
            shape = (np.nan, np.nan)

        self.output_types = [OutputType.dataframe]
        return self.new_tileable(
            [],
            None,
            shape=shape,
            dtypes=self.dtypes,
            index_value=index_value,
            columns_value=columns_value,
            chunk_bytes=chunk_bytes,
            chunk_size=chunk_size,
        )


def read_odps_query(
    query: str,
    odps_entry: Optional[ODPS] = None,
    index_col: Union[None, str, List[str]] = None,
    string_as_binary: Optional[bool] = None,
    sql_hints: Optional[Dict[str, str]] = None,
    anonymous_col_prefix: str = _DEFAULT_ANONYMOUS_COL_PREFIX,
    skip_schema: bool = False,
    **kw,
):
    """
    Read data from a MaxCompute (ODPS) SQL query into a DataFrame.

    Some columns can be designated as the index via ``index_col``; if none are
    specified, a RangeIndex is generated.

    Parameters
    ----------
    query: str
        MaxCompute SQL statement.
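    odps_entry: ODPS, optional
        ODPS entry to use. When absent, the global ODPS entry or one built
        from environment variables is used.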
    index_col: Union[None, str, List[str]]
        Columns to be specified as indexes.
    string_as_binary: bool, optional
        Whether to convert string columns to binary.
    sql_hints: Dict[str, str], optional
        User-specified SQL hints.
    anonymous_col_prefix: str, optional
        Prefix used to rename anonymous result columns (for instance ``_c0``),
        '_anon_col_' by default.
    skip_schema: bool, optional
        Skip resolving the output schema before execution. When enabled, the
        resulting DataFrame cannot be used as input to other DataFrame
        operators until it has been executed.

    Returns
    -------
    result: DataFrame
        DataFrame holding the result of the MaxCompute (ODPS) query.
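
    Examples
    --------
    A minimal, illustrative sketch; the table and column names below are
    hypothetical::

        df = read_odps_query(
            "SELECT category, SUM(amount) AS total FROM sales GROUP BY category",
            index_col="category",
        )

    The query itself is not run at this point; only an EXPLAIN statement is
    submitted up front to resolve the output schema (unless ``skip_schema``
    is True).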
    """
    no_split_sql = kw.pop("no_split_sql", False)

    hints = (options.sql.settings or {}).copy()
    if sql_hints:
        hints.update(sql_hints)

    odps_entry = odps_entry or ODPS.from_global() or ODPS.from_environments()
    if odps_entry is None:
        raise ValueError(
            "Need to provide an odps_entry argument or hold a default ODPS entry."
        )

    if options.session.enable_schema or odps_entry.is_schema_namespace_enabled():
        hints["odps.namespace.schema"] = "true"
        hints["odps.sql.allow.namespace.schema"] = "true"

    hints["odps.sql.submit.mode"] = "script"
    # FIXME: workaround for the multi-stage split process
    hints["odps.sql.object.table.split.by.object.size.enabled"] = "false"

    col_renames = {}
    if not skip_schema:
        explain_stmt = _build_explain_sql(query, no_split=no_split_sql)
        inst = odps_entry.execute_sql(explain_stmt, hints=hints)
        logger.debug("Explain instance ID: %s", inst.id)
        explain_str = list(inst.get_task_results().values())[0]

        try:
            odps_schema = _parse_explained_schema(explain_str)
        except BaseException as ex:
            exc = ValueError(
                f"Failed to obtain schema from SQL explain: {ex!r}"
                f"\nExplain instance ID: {inst.id}"
            )
            raise exc.with_traceback(ex.__traceback__) from None

        new_columns = []
        for col in odps_schema.columns:
            anon_match = _ANONYMOUS_COL_REGEX.match(col.name)
            if anon_match and col.name not in query:
                new_name = anonymous_col_prefix + anon_match.group(1)
                col_renames[col.name] = new_name
                new_columns.append(Column(new_name, col.type))
            else:
                new_columns.append(col)

        dtypes = odps_schema_to_pandas_dtypes(OdpsSchema(new_columns))
    else:
        dtypes = None

    if not index_col:
        index_dtypes = None
    else:
        if dtypes is None:
            raise ValueError("Cannot configure index_col when skip_schema is True")

        if isinstance(index_col, str):
            index_col = [index_col]
        index_col_set = set(index_col)
        data_cols = [c for c in dtypes.index if c not in index_col_set]
        idx_dtype_vals = [dtypes[c] for c in index_col]
        col_dtype_vals = [dtypes[c] for c in data_cols]
        index_dtypes = pd.Series(idx_dtype_vals, index=index_col)
        dtypes = pd.Series(col_dtype_vals, index=data_cols)

    chunk_bytes = kw.pop("chunk_bytes", None)
    chunk_size = kw.pop("chunk_size", None)
    op = DataFrameReadODPSQuery(
        query=query,
        dtypes=dtypes,
        use_arrow_dtype=kw.pop("use_arrow_dtype", True),
        string_as_binary=string_as_binary,
        index_columns=index_col,
        index_dtypes=index_dtypes,
        column_renames=col_renames,
        no_split_sql=no_split_sql,
    )
    return op(chunk_bytes=chunk_bytes, chunk_size=chunk_size)
