# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import string
from collections import defaultdict
from typing import Any, Dict, Tuple, Union

import numpy as np
import pandas as pd
import pyarrow as pa
from odps import types as odps_types
from pandas.api import types as pd_types

from ...core import TILEABLE_TYPE, OutputType
from ...lib.dtypes_extension import ArrowDtype
from ...protocol import DataFrameTableMeta
from ...tensor.core import TENSOR_TYPE

# Prefix put in front of a tileable's key to form its table name
# (see build_dataframe_table_meta).
_TEMP_TABLE_PREFIX = "tmp_mf_"
# Same value as the "_idx_{i}" index column name that
# build_dataframe_table_meta generates for i == 0.
DEFAULT_SINGLE_INDEX_NAME = "_idx_0"

# Exact-match lookup from Arrow types to ODPS types. Parameterized Arrow types
# (list, map, struct, decimal128) are not listed here; they are handled
# separately in arrow_type_to_odps_type.
# Millisecond timestamps map to ODPS datetime, while microsecond and
# nanosecond timestamps both map to ODPS timestamp.
_arrow_to_odps_types = {
    pa.string(): odps_types.string,
    pa.binary(): odps_types.binary,
    pa.int8(): odps_types.tinyint,
    pa.int16(): odps_types.smallint,
    pa.int32(): odps_types.int_,
    pa.int64(): odps_types.bigint,
    pa.bool_(): odps_types.boolean,
    pa.float32(): odps_types.float_,
    pa.float64(): odps_types.double,
    pa.date32(): odps_types.date,
    pa.timestamp("ms"): odps_types.datetime,
    pa.timestamp("us"): odps_types.timestamp,
    pa.timestamp("ns"): odps_types.timestamp,
}

# Exact-match lookup from ODPS types to Arrow types. This is not a strict
# inverse of _arrow_to_odps_types: json is represented as an Arrow string, and
# both timestamp and timestamp_ntz become nanosecond Arrow timestamps.
# Array, map, struct, decimal, varchar and char types are handled separately
# in odps_type_to_arrow_type.
_odps_type_to_arrow = {
    odps_types.string: pa.string(),
    odps_types.binary: pa.binary(),
    odps_types.tinyint: pa.int8(),
    odps_types.smallint: pa.int16(),
    odps_types.int_: pa.int32(),
    odps_types.bigint: pa.int64(),
    odps_types.boolean: pa.bool_(),
    odps_types.float_: pa.float32(),
    odps_types.double: pa.float64(),
    odps_types.date: pa.date32(),
    odps_types.datetime: pa.timestamp("ms"),
    odps_types.json: pa.string(),
    odps_types.timestamp: pa.timestamp("ns"),
    odps_types.timestamp_ntz: pa.timestamp("ns"),
}

# Arrow type classes that is_based_for_pandas_dtype tests with isinstance().
# typing.Union with a single member evaluates to that member itself, so this
# name is bound to pa.MapType and is usable as an isinstance() target.
_based_for_pandas_pa_dtypes = Union[pa.MapType]


def is_based_for_pandas_dtype(dtype: pa.DataType) -> bool:
    """
    Tell whether an arrow type must be carried by an ``ArrowDtype`` in pandas.

    Parameters
    ----------
    dtype:
        arrow data type to inspect

    Returns
    -------
        True when ``dtype`` is an instance of one of the arrow type classes in
        ``_based_for_pandas_pa_dtypes``, False otherwise

    Raises
    ------
    ImportError
        when the type needs ``ArrowDtype`` but ``ArrowDtype`` is None in the
        current environment
    """
    if isinstance(dtype, _based_for_pandas_pa_dtypes):
        # The type can only be represented through ArrowDtype, so its absence
        # is an error rather than a negative answer.
        if ArrowDtype is None:
            raise ImportError("ArrowDtype is not supported in current environment")
        return True
    return False


def pandas_types_to_arrow_schema(df_obj: pd.DataFrame) -> pa.Schema:
    """
    Build the arrow schema of a pandas DataFrame, without its index.

    The schema is first inferred by ``pa.Schema.from_pandas``. Every column
    whose pandas dtype is an ``ArrowDtype`` then has its field replaced by one
    using the arrow type stored in that dtype (``pyarrow_dtype``).

    Parameters
    ----------
    df_obj:
        pandas DataFrame whose column dtypes are converted

    Returns
    -------
        arrow schema with one field per DataFrame column
    """
    schema = pa.Schema.from_pandas(df_obj, preserve_index=False)
    if ArrowDtype is None:
        # No ArrowDtype available, hence no column can need the override.
        return schema

    for idx, (col_name, col_dtype) in enumerate(df_obj.dtypes.items()):
        if isinstance(col_dtype, ArrowDtype):
            # pa.Schema is immutable: ``set`` returns a new schema, so the
            # result has to be rebound or the replacement is lost.
            schema = schema.set(idx, pa.field(col_name, col_dtype.pyarrow_dtype))
    return schema


def arrow_type_to_odps_type(
    arrow_type: pa.DataType, col_name: str, unknown_as_string: bool = False
) -> odps_types.DataType:
    """
    Convert an arrow data type into the corresponding ODPS data type.

    Types listed in ``_arrow_to_odps_types`` are looked up directly. List, map
    and struct types are converted recursively, and decimal128 types keep
    their precision and scale.

    Parameters
    ----------
    arrow_type:
        arrow data type to convert
    col_name:
        name of the column owning the type, only used in the error message
    unknown_as_string:
        if True, types without a known mapping are converted to ODPS string
        instead of raising

    Returns
    -------
        converted ODPS data type

    Raises
    ------
    TypeError
        when the type has no known mapping and ``unknown_as_string`` is False
    """
    if arrow_type in _arrow_to_odps_types:
        return _arrow_to_odps_types[arrow_type]
    elif isinstance(arrow_type, pa.ListType):
        return odps_types.Array(
            arrow_type_to_odps_type(arrow_type.value_type, col_name, unknown_as_string)
        )
    elif isinstance(arrow_type, pa.MapType):
        return odps_types.Map(
            arrow_type_to_odps_type(arrow_type.key_type, col_name, unknown_as_string),
            arrow_type_to_odps_type(arrow_type.item_type, col_name, unknown_as_string),
        )
    elif isinstance(arrow_type, pa.StructType):
        type_dict = {}
        for idx in range(arrow_type.num_fields):
            field = arrow_type[idx]
            type_dict[field.name] = arrow_type_to_odps_type(
                field.type, col_name, unknown_as_string
            )
        return odps_types.Struct(type_dict)
    elif isinstance(arrow_type, pa.Decimal128Type):
        return odps_types.Decimal(arrow_type.precision, arrow_type.scale)

    if unknown_as_string:
        return odps_types.string

    # Adjacent literals are concatenated verbatim, so the separating space
    # after the comma must be part of the first literal.
    raise TypeError(
        "Unknown type {}, column name is {}, "
        "specify `unknown_as_string=True` "
        "or `as_type` to set column dtype".format(arrow_type, col_name)
    )


def arrow_schema_to_odps_schema(
    arrow_schema: pa.Schema, unknown_as_string: bool = False
) -> odps_types.OdpsSchema:
    """
    Convert an arrow schema into an ODPS schema, column by column.

    Parameters
    ----------
    arrow_schema:
        arrow schema to convert
    unknown_as_string:
        forwarded to ``arrow_type_to_odps_type`` for every column

    Returns
    -------
        ODPS schema holding one column per arrow field, in the same order
    """
    columns = [
        odps_types.Column(
            name,
            arrow_type_to_odps_type(
                arrow_type, name, unknown_as_string=unknown_as_string
            ),
        )
        for name, arrow_type in zip(arrow_schema.names, arrow_schema.types)
    ]
    return odps_types.OdpsSchema(columns)


def odps_type_to_arrow_type(
    odps_type: odps_types.DataType, col_name: str
) -> pa.DataType:
    """
    Convert an ODPS data type into the corresponding arrow data type.

    Types listed in ``_odps_type_to_arrow`` are looked up directly. Array, map
    and struct types are converted recursively, decimals become decimal128 and
    varchar / char become arrow strings.

    Parameters
    ----------
    odps_type:
        ODPS data type to convert
    col_name:
        name of the column owning the type, only used in the error message

    Returns
    -------
        converted arrow data type

    Raises
    ------
    TypeError
        when the ODPS type has no supported arrow counterpart
    """
    if odps_type in _odps_type_to_arrow:
        return _odps_type_to_arrow[odps_type]

    if isinstance(odps_type, odps_types.Array):
        return pa.list_(odps_type_to_arrow_type(odps_type.value_type, col_name))
    if isinstance(odps_type, odps_types.Map):
        return pa.map_(
            odps_type_to_arrow_type(odps_type.key_type, col_name),
            odps_type_to_arrow_type(odps_type.value_type, col_name),
        )
    if isinstance(odps_type, odps_types.Struct):
        fields = [
            (k, odps_type_to_arrow_type(v, col_name))
            for k, v in odps_type.field_types.items()
        ]
        return pa.struct(fields)
    if isinstance(odps_type, odps_types.Decimal):
        if odps_type.name == "decimal":
            # legacy decimal data without precision or scale
            # precision data from internal compat mode
            return pa.decimal128(38, 18)
        precision = odps_type.precision or odps_types.Decimal._max_precision
        # A scale of 0 is a legitimate value and must be kept; only an
        # unspecified (None) scale falls back to the maximum scale.
        scale = odps_type.scale
        if scale is None:
            scale = odps_types.Decimal._max_scale
        return pa.decimal128(precision, scale)
    if isinstance(odps_type, (odps_types.Varchar, odps_types.Char)):
        return pa.string()

    raise TypeError(
        "Unsupported type {}, column name is {}".format(odps_type, col_name)
    )


def odps_schema_to_arrow_schema(
    odps_schema: odps_types.OdpsSchema, with_partitions: bool = False
) -> pa.Schema:
    """
    Convert an ODPS schema into an arrow schema.

    Parameters
    ----------
    odps_schema:
        ODPS schema to convert
    with_partitions:
        if True, ``odps_schema.columns`` is converted; otherwise only
        ``odps_schema.simple_columns`` is used

    Returns
    -------
        arrow schema with one field per selected ODPS column
    """
    if with_partitions:
        source_columns = odps_schema.columns
    else:
        source_columns = odps_schema.simple_columns

    fields = [
        pa.field(column.name, odps_type_to_arrow_type(column.type, column.name))
        for column in source_columns
    ]
    return pa.schema(fields)


def odps_schema_to_pandas_dtypes(
    odps_schema: odps_types.OdpsSchema, with_partitions: bool = False
) -> pd.Series:
    """
    Compute the pandas dtypes matching an ODPS schema.

    The schema is converted to arrow, an empty arrow table is built from it
    and converted to pandas, and the dtypes of that DataFrame are returned.

    Parameters
    ----------
    odps_schema:
        ODPS schema to convert
    with_partitions:
        forwarded to ``odps_schema_to_arrow_schema``

    Returns
    -------
        dtypes of the empty DataFrame built from the schema
    """
    empty_table = odps_schema_to_arrow_schema(
        odps_schema, with_partitions=with_partitions
    ).empty_table()
    empty_df = arrow_table_to_pandas_dataframe(empty_table)
    return empty_df.dtypes


def arrow_table_to_pandas_dataframe(
    table: pa.Table, meta: DataFrameTableMeta = None
) -> pd.DataFrame:
    """
    Convert an arrow table into a pandas DataFrame.

    Parameters
    ----------
    table:
        arrow table to convert
    meta:
        optional table meta; when given, column dtypes are converted to the
        index dtypes followed by the column dtypes recorded in it

    Returns
    -------
        converted pandas DataFrame
    """

    def _map_arrow_type(arrow_type):
        # Only types flagged by is_based_for_pandas_dtype get an ArrowDtype;
        # returning None keeps the default conversion for the type.
        if is_based_for_pandas_dtype(arrow_type):
            return ArrowDtype(arrow_type)
        return None

    df = table.to_pandas(types_mapper=_map_arrow_type, ignore_metadata=True)
    if not meta:
        return df

    # Dtypes expected for the table columns: index levels first, data after.
    expected_dtypes = [*meta.pd_index_dtypes.values, *meta.pd_column_dtypes.values]

    pending_casts = {}
    for actual_dtype, col, wanted_dtype in zip(
        df.dtypes.values, df.columns, expected_dtypes
    ):
        if actual_dtype != wanted_dtype:
            # Converting tz-aware dtype to tz-naive dtype is a special case.
            # In numpy1.19, we can't use numpy.dtype.DateTime64Dtype
            drops_timezone = (
                isinstance(actual_dtype, pd.DatetimeTZDtype)
                and isinstance(wanted_dtype, np.dtype)
                and wanted_dtype.name.startswith("datetime64")
            )
            if drops_timezone:
                df[col] = df[col].dt.tz_localize(None)
            else:
                pending_casts[col] = wanted_dtype

    # All remaining mismatches are cast in a single astype call.
    if pending_casts:
        df = df.astype(pending_casts)

    return df


def pandas_dataframe_to_arrow_table(df: pd.DataFrame, nthreads=1) -> pa.Table:
    """
    Convert a pandas DataFrame into an arrow table without its index.

    Parameters
    ----------
    df:
        pandas DataFrame to convert
    nthreads:
        forwarded to ``pa.Table.from_pandas``

    Returns
    -------
        arrow table using the schema from ``pandas_types_to_arrow_schema``
    """
    target_schema = pandas_types_to_arrow_schema(df)
    return pa.Table.from_pandas(
        df,
        schema=target_schema,
        nthreads=nthreads,
        preserve_index=False,
    )


def is_scalar_object(df_obj: Any) -> bool:
    """
    Tell whether an object is a scalar.

    An object counts as scalar when it is a ``TENSOR_TYPE`` instance with an
    empty shape, or when ``pandas.api.types.is_scalar`` accepts it.
    """
    if isinstance(df_obj, TENSOR_TYPE) and df_obj.shape == ():
        return True
    return pd_types.is_scalar(df_obj)


def _scalar_as_index(df_obj: Any) -> pd.Index:
    """
    Build an empty pandas Index for a scalar object.

    For a tileable the index takes the tileable's dtype. For any other value
    a one-element index is built from it and sliced down to zero length.
    """
    if not isinstance(df_obj, TILEABLE_TYPE):
        return pd.Index([df_obj])[:0]
    return pd.Index([], dtype=df_obj.dtype)


def pandas_to_odps_schema(
    df_obj: Any,
    unknown_as_string: bool = False,
    ignore_index=False,
) -> Tuple[odps_types.OdpsSchema, DataFrameTableMeta]:
    """
    Compute the ODPS schema and table meta for a DataFrame-like object.

    A zero-length object with the same dtypes as ``df_obj`` is built, passed
    through ``pandas_to_arrow``, and the resulting arrow schema is converted
    into an ODPS schema.

    Parameters
    ----------
    df_obj:
        DataFrame, Series, Index (pandas or MaxFrame) or scalar object
    unknown_as_string:
        forwarded to ``arrow_schema_to_odps_schema``
    ignore_index:
        forwarded to ``pandas_to_arrow``

    Returns
    -------
        tuple of the ODPS schema and the table meta from ``pandas_to_arrow``
    """
    from ... import dataframe as md
    from .arrow import pandas_to_arrow

    ms_dtype = np.dtype("datetime64[ms]")

    # Zero-length index carrying the dtype(s) of the object's index.
    if is_scalar_object(df_obj):
        empty_index = None
    elif hasattr(df_obj, "index_value"):
        empty_index = df_obj.index_value.to_pandas()[:0]
    elif isinstance(df_obj, pd.Index):
        empty_index = df_obj[:0]
    else:
        empty_index = df_obj.index[:0]

    # Column labels, when the object has any.
    empty_columns = None
    if hasattr(df_obj, "columns_value"):
        empty_columns = df_obj.dtypes.index
    elif hasattr(df_obj, "columns"):
        empty_columns = df_obj.columns

    # ms_cols marks the datetime64[ms] data: a list of column names for a
    # DataFrame, a list of level positions for a MultiIndex, and a single
    # comparison result for a Series or a flat Index.
    ms_cols = None
    if isinstance(df_obj, (md.DataFrame, pd.DataFrame)):
        col_dtypes = df_obj.dtypes
        empty_df_obj = pd.DataFrame(
            [], columns=empty_columns, index=empty_index
        ).astype(col_dtypes)
        ms_cols = [name for name, dtype in col_dtypes.items() if dtype == ms_dtype]
    elif isinstance(df_obj, (md.Series, pd.Series)):
        empty_series = pd.Series([], name=df_obj.name, index=empty_index)
        empty_df_obj = empty_series.astype(df_obj.dtype)
        ms_cols = df_obj.dtype == ms_dtype
    elif isinstance(df_obj, (md.Index, pd.Index)):
        empty_df_obj = empty_index
        if isinstance(empty_index, pd.MultiIndex):
            ms_cols = [
                level
                for level, dtype in enumerate(empty_index.dtypes.values)
                if dtype == ms_dtype
            ]
        else:
            ms_cols = df_obj.dtype == ms_dtype
    else:
        empty_df_obj = df_obj

    arrow_data, table_meta = pandas_to_arrow(
        empty_df_obj, ignore_index=ignore_index, ms_cols=ms_cols
    )
    odps_schema = arrow_schema_to_odps_schema(
        arrow_data.schema, unknown_as_string=unknown_as_string
    )
    return odps_schema, table_meta


def build_table_column_name(
    col_idx: int, pd_col_name: Any, records: Dict[str, str]
) -> str:
    """
    Convert column name to MaxCompute acceptable names

    A legal name has 1 to 128 characters, starts with an ASCII letter or an
    underscore and holds only ASCII letters, digits and underscores. The
    pandas name is lower-cased first; a name that is still not legal is
    replaced with ``_column_<col_idx>``.

    Parameters
    ----------
    col_idx:
        index of the column
    pd_col_name:
        column name in pandas
    records:
        record for existing columns, which is read as a cache and updated
        with the newly converted name

    Returns
    -------
        converted column name
    """
    # Names converted before are returned as they were recorded.
    if pd_col_name in records:
        return records[pd_col_name]

    leading_chars = string.ascii_letters + "_"
    allowed_chars = leading_chars + string.digits

    def _is_col_name_legal(name: str):
        if not 1 <= len(name) <= 128:
            return False
        if name[0] not in leading_chars:
            return False
        return all(ch in allowed_chars for ch in name)

    if isinstance(pd_col_name, tuple):
        # multi-level column labels are flattened with underscores
        raw_name = "_".join(map(str, pd_col_name))
    elif isinstance(pd_col_name, str):
        raw_name = pd_col_name
    else:
        raw_name = str(pd_col_name)

    lowered = raw_name.lower()
    table_col_name = lowered if _is_col_name_legal(lowered) else f"_column_{col_idx}"

    records[pd_col_name] = table_col_name
    return table_col_name


def build_dataframe_table_meta(
    df_obj: Any, ignore_index: bool = False
) -> DataFrameTableMeta:
    """
    Build the table meta describing how an object is laid out as a table.

    Parameters
    ----------
    df_obj:
        DataFrame, Series, Index (pandas or MaxFrame) or scalar object
    ignore_index:
        if True, no index columns are recorded unless ``df_obj`` is an index

    Returns
    -------
        meta with table column names, index column names and pandas dtypes

    Raises
    ------
    TypeError
        when ``df_obj`` is none of the accepted kinds of objects
    """
    from ... import dataframe as md

    # Classify the object.
    type_checks = (
        ((md.DataFrame, pd.DataFrame), OutputType.dataframe),
        ((md.Series, pd.Series), OutputType.series),
        ((md.Index, pd.Index), OutputType.index),
    )
    for accepted_types, output_type in type_checks:
        if isinstance(df_obj, accepted_types):
            obj_type = output_type
            break
    else:
        if not is_scalar_object(df_obj):  # pragma: no cover
            raise TypeError(f"Cannot accept type {type(df_obj)}")
        obj_type = OutputType.scalar

    # Collect column dtypes, column level names and the index per kind.
    if obj_type == OutputType.dataframe:
        pd_dtypes = df_obj.dtypes
        column_index_names = list(pd_dtypes.index.names)
        index_obj = df_obj.index
    elif obj_type == OutputType.series:
        pd_dtypes = pd.Series([df_obj.dtype], index=[df_obj.name])
        column_index_names = [None]
        index_obj = df_obj.index
    elif obj_type == OutputType.index:
        pd_dtypes = pd.Series([])
        column_index_names = []
        index_obj = df_obj
    else:
        pd_dtypes = pd.Series([])
        column_index_names = []
        index_obj = _scalar_as_index(df_obj)

    # Only tileables get a table name, derived from their key.
    table_name = (
        _TEMP_TABLE_PREFIX + str(df_obj.key)
        if isinstance(df_obj, TILEABLE_TYPE)
        else None
    )

    if obj_type == OutputType.series and df_obj.name is None:
        # use special table column name for series
        pd_col_names = ["_data"]
    else:
        pd_col_names = pd_dtypes.index

    # First pass: convert every pandas name and count the converted names.
    name_cache = {}
    name_counts = defaultdict(int)
    sql_columns = []
    for pos, pd_name in enumerate(pd_col_names):
        sql_name = build_table_column_name(pos, pd_name, name_cache)
        sql_columns.append(sql_name)
        name_counts[sql_name] += 1

    # Second pass: names used more than once get a numeric suffix, skipping
    # suffixed names that equal one of the converted names.
    next_suffix = defaultdict(int)
    final_sql_columns = []
    for sql_name in sql_columns:
        if name_counts[sql_name] <= 1:
            final_sql_columns.append(sql_name)
            continue
        while True:
            candidate = f"{sql_name}_{next_suffix[sql_name]}"
            next_suffix[sql_name] += 1
            if candidate not in name_counts:
                break
        final_sql_columns.append(candidate)

    pd_index_val = (
        index_obj.index_value.to_pandas()
        if hasattr(index_obj, "index_value")
        else index_obj
    )
    level_dtypes = [
        pd_index_val.get_level_values(level).dtype
        for level in range(pd_index_val.nlevels)
    ]
    index_dtypes = pd.Series(level_dtypes, index=pd_index_val.names)

    # An index object always keeps its index columns.
    keep_index = not ignore_index or obj_type == OutputType.index
    if keep_index:
        table_index_column_names = [f"_idx_{i}" for i in range(len(index_obj.names))]
        pd_index_dtypes = index_dtypes
    else:
        table_index_column_names = []
        pd_index_dtypes = pd.Series([], index=[])

    return DataFrameTableMeta(
        table_name=table_name,
        type=obj_type,
        table_column_names=final_sql_columns,
        table_index_column_names=table_index_column_names,
        pd_column_dtypes=pd_dtypes,
        pd_column_level_names=column_index_names,
        pd_index_dtypes=pd_index_dtypes,
    )
