# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numbers import Integral

import numpy as np
import pandas as pd
from pandas.core.dtypes.cast import find_common_type
from pandas.core.indexing import IndexingError

from ... import opcodes
from ...config import options
from ...core import ENTITY_TYPE, OutputType
from ...serialization.serializables import AnyField, KeyField, ListField
from ...tensor import asarray
from ...tensor.indexing.core import calc_shape
from ..operators import DATAFRAME_TYPE, DataFrameOperator, DataFrameOperatorMixin
from ..utils import indexing_index_value

_ILOC_ERROR_MSG = (
    "Location based indexing can only have [integer, "
    "integer slice (START point is INCLUDED, END point is EXCLUDED), "
    "listlike of integers, boolean array] types"
)


def process_iloc_indexes(inp, indexes):
    """
    Normalize and validate positional (``iloc``) indexers for ``inp``.

    Parameters
    ----------
    inp
        The object being indexed. Its ``ndim`` and ``shape`` are read; its
        ``index_value`` / ``columns_value`` are converted to pandas only to
        build the error message for an invalid slice component.
    indexes
        A single indexer or a tuple with one indexer per axis. Missing
        trailing axes are padded with ``slice(None)``.

    Returns
    -------
    list
        One indexer per axis. Slices and integers are kept as-is; list-likes
        become 1-d numpy arrays (or tensors) of int64 or bool dtype.

    Raises
    ------
    IndexingError
        If more indexers than dimensions are given, or an indexer is a tuple.
    TypeError
        If a slice start/stop/step is neither ``None`` nor an integer.
    NotImplementedError
        If a tensor indexer on axis 1 cannot be fetched.
    ValueError
        If a list-like indexer is not 1-d, or the indexer kind is unsupported.
    IndexError
        If an integer indexer is out of bounds on an axis of known length.
    """
    ndim = inp.ndim

    if not isinstance(indexes, tuple):
        indexes = (indexes,)
    if len(indexes) < ndim:
        indexes += (slice(None),) * (ndim - len(indexes))
    if len(indexes) > ndim:
        raise IndexingError("Too many indexers")

    new_indexes = []
    # check each index
    for ax, index in enumerate(indexes):
        if isinstance(index, tuple):
            # a tuple should already have been caught by this point
            # so don't treat a tuple as a valid indexer
            raise IndexingError("Too many indexers")
        elif isinstance(index, slice):
            for val in (index.start, index.stop, index.step):
                # Positional slices accept only integer components. Checking
                # the type directly is required: probing a pandas index with
                # e.g. a float or a string raises IndexError (not TypeError),
                # so such values would otherwise slip through validation.
                if val is not None and not isinstance(val, Integral):
                    pd_index = (
                        inp.index_value if ax == 0 else inp.columns_value
                    ).to_pandas()
                    raise TypeError(
                        f"cannot do slice indexing on {type(pd_index)} "
                        f"with these indexers [{val}] of {type(val)}"
                    )
            new_indexes.append(index)
        elif isinstance(index, (list, np.ndarray, pd.Series, ENTITY_TYPE)):
            if not isinstance(index, ENTITY_TYPE):
                index = np.asarray(index)
            else:
                index = asarray(index)
                if ax == 1:
                    # do not support tensor index on axis 1
                    # because if so, the dtypes and columns_value would be unknown
                    try:
                        index = index.fetch()
                    except (RuntimeError, ValueError):
                        raise NotImplementedError(
                            "indexer on axis columns cannot be non-executed tensor"
                        )
            if index.dtype != np.bool_:
                index = index.astype(np.int64)
            if index.ndim != 1:
                raise ValueError(
                    "Buffer has wrong number of dimensions "
                    f"(expected 1, got {index.ndim})"
                )
            new_indexes.append(index)
        elif isinstance(index, Integral):
            shape = inp.shape[ax]
            # a NaN length means the axis size is unknown; skip bounds check
            if not np.isnan(shape):
                if index < -shape or index >= shape:
                    raise IndexError("single positional indexer is out-of-bounds")
            new_indexes.append(index)
        else:
            raise ValueError(_ILOC_ERROR_MSG)

    return new_indexes


class DataFrameIloc:
    """
    Positional indexer object produced by ``iloc``.

    Reading (``obj.iloc[...]``) builds an iloc getitem operator and applies
    it to the wrapped object. Writing (``obj.iloc[...] = scalar``) builds an
    iloc setitem operator and rebinds the wrapped object's ``data`` to the
    operator's output.
    """

    def __init__(self, obj):
        # the object being indexed
        self._obj = obj

    def __getitem__(self, indexes):
        target = self._obj
        is_frame = isinstance(target, DATAFRAME_TYPE)
        op_type = DataFrameIlocGetItem if is_frame else SeriesIlocGetItem
        normalized = process_iloc_indexes(target, indexes)
        return op_type(indexes=normalized)(target)

    def __setitem__(self, indexes, value):
        # only scalar assignment is implemented
        if not np.isscalar(value):
            raise NotImplementedError("Only scalar value is supported to set by iloc")

        target = self._obj
        if isinstance(target, DATAFRAME_TYPE):
            op_type = DataFrameIlocSetItem
        else:
            op_type = SeriesIlocSetItem
        op = op_type(indexes=process_iloc_indexes(target, indexes), value=value)

        # in-place semantics: the wrapped object now refers to the new data
        target.data = op(target).data


class HeadTailOptimizedOperatorMixin(DataFrameOperatorMixin):
    """
    Helpers that recognise ``head``-like and ``tail``-like row slices.

    Subclasses are expected to expose an ``indexes`` list whose first entry
    is the row indexer.
    """

    __slots__ = ()

    @classmethod
    def _is_head(cls, index0):
        # rows [0, stop) with a positive stop
        starts_at_zero = index0.start is None or index0.start == 0
        return starts_at_zero and index0.stop is not None and index0.stop > 0

    @classmethod
    def _is_tail(cls, index0):
        # rows [start, end) with a negative start and an open end
        has_negative_start = index0.start is not None and index0.start < 0
        return has_negative_start and index0.stop is None

    @classmethod
    def _is_indexes_head_or_tail(cls, indexes):
        first = indexes[0]
        # the row indexer must be a slice with a unit (or default) step
        if not isinstance(first, slice):
            return False
        if first.step is not None and first.step != 1:
            return False
        # when a column indexer is present it must select every column
        if len(indexes) == 2:
            second = indexes[1]
            if not isinstance(second, slice) or second != slice(None):
                return False
        return bool(cls._is_tail(first) or cls._is_head(first))

    def can_be_optimized(self):
        indexes = self.indexes
        threshold = options.optimize.head_optimize_threshold
        return (
            self._is_indexes_head_or_tail(indexes)
            and self._is_head(indexes[0])
            and indexes[0].stop <= threshold
        )


class DataFrameIlocGetItem(DataFrameOperator, HeadTailOptimizedOperatorMixin):
    """
    Positional (``iloc``) selection on a DataFrame.

    The result is a scalar, a Series or a DataFrame depending on which of
    the two axis indexers are integers.
    """

    _op_type_ = opcodes.DATAFRAME_ILOC_GETITEM

    _input = KeyField("input")
    indexes = ListField("indexes", default=None)

    def __init__(self, gpu=None, sparse=False, output_types=None, **kw):
        super().__init__(gpu=gpu, sparse=sparse, _output_types=output_types, **kw)
        # default to a DataFrame output when none was requested
        if not self.output_types:
            self.output_types = [OutputType.dataframe]

    @property
    def input(self):
        return self._input

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        remaining = iter(self._inputs)
        # the indexed DataFrame comes first ...
        self._input = next(remaining)
        # ... followed by entity indexers, in the order they appear
        self.indexes = [
            next(remaining) if isinstance(idx, ENTITY_TYPE) else idx
            for idx in self.indexes
        ]

    def __call__(self, df):
        # Note [Fancy Index of Numpy and Pandas]
        #
        # The numpy and pandas.iloc have different semantic when processing fancy index:
        #
        # >>> np.ones((3,3))[[1,2],[1,2]]
        # array([1., 1.])
        #
        # >>> pd.DataFrame(np.ones((3,3))).iloc[[1,2],[1,2]]
        #    1    2
        # 1  1.0  1.0
        # 2  1.0  1.0
        #
        # Thus, we process the index along the two axes of the DataFrame separately.
        row_index, col_index = self.indexes[0], self.indexes[1]
        row_shape = tuple(calc_shape((df.shape[0],), (row_index,)))
        col_shape = tuple(calc_shape((df.shape[1],), (col_index,)))

        inputs = [df]
        inputs.extend(idx for idx in self.indexes if isinstance(idx, ENTITY_TYPE))

        # NB: pandas only compresses the result to series when index on one of axis is integral
        if isinstance(col_index, Integral):
            dtype = df.dtypes.iloc[col_index]
            index_value = indexing_index_value(df.index_value, row_index)
            if isinstance(row_index, Integral):
                # both axes integral -> a single cell
                return self.new_scalar(inputs, dtype=dtype)
            # one column selected -> Series along the rows
            return self.new_series(
                inputs,
                shape=row_shape,
                dtype=dtype,
                index_value=index_value,
                name=df.dtypes.index[col_index],
            )

        if isinstance(row_index, Integral):
            # one row selected -> Series along the columns
            dtype = find_common_type(list(df.dtypes.iloc[col_index].values))
            index_value = indexing_index_value(df.columns_value, col_index)
            return self.new_series(
                inputs, shape=col_shape, dtype=dtype, index_value=index_value
            )

        return self.new_dataframe(
            inputs,
            shape=row_shape + col_shape,
            dtypes=df.dtypes.iloc[col_index],
            index_value=indexing_index_value(df.index_value, row_index),
            columns_value=indexing_index_value(
                df.columns_value, col_index, store_data=True
            ),
        )


class SeriesIlocGetItem(DataFrameOperator, HeadTailOptimizedOperatorMixin):
    """
    Positional (``iloc``) selection on a Series.

    An integer position yields a scalar; any other indexer yields a Series.
    """

    _op_module_ = "series"
    _op_type_ = opcodes.DATAFRAME_ILOC_GETITEM

    _input = KeyField("input")
    indexes = ListField("indexes", default=None)

    def __init__(self, gpu=None, sparse=False, output_types=None, **kw):
        super().__init__(gpu=gpu, sparse=sparse, _output_types=output_types, **kw)
        # default to a Series output when none was requested
        if not self.output_types:
            self.output_types = [OutputType.series]

    @property
    def input(self):
        return self._input

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        remaining = iter(self._inputs)
        # the indexed Series comes first, then entity indexers in order
        self._input = next(remaining)
        self.indexes = [
            next(remaining) if isinstance(idx, ENTITY_TYPE) else idx
            for idx in self.indexes
        ]

    def __call__(self, series):
        position = self.indexes[0]
        if isinstance(position, Integral):
            # a single integer position selects one element
            return self.new_scalar([series], dtype=series.dtype)

        shape = tuple(calc_shape(series.shape, self.indexes))
        index_value = indexing_index_value(series.index_value, position)
        inputs = [
            series,
            *(idx for idx in self.indexes if isinstance(idx, ENTITY_TYPE)),
        ]
        return self.new_series(
            inputs,
            shape=shape,
            dtype=series.dtype,
            index_value=index_value,
            name=series.name,
        )


class IndexIlocGetItem(DataFrameOperator, DataFrameOperatorMixin):
    """
    Positional (``iloc``-style) selection on an Index.

    An integer position yields a scalar; any other indexer yields an Index.
    """

    _op_module_ = "index"
    _op_type_ = opcodes.DATAFRAME_ILOC_GETITEM

    _input = KeyField("input")
    indexes = ListField("indexes", default=None)

    def __init__(self, gpu=None, sparse=False, output_types=None, **kw):
        super().__init__(gpu=gpu, sparse=sparse, _output_types=output_types, **kw)
        # default to an Index output when none was requested
        if not self.output_types:
            self.output_types = [OutputType.index]

    @property
    def input(self):
        return self._input

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        remaining = iter(self._inputs)
        # the indexed Index comes first, then entity indexers in order
        self._input = next(remaining)
        self.indexes = [
            next(remaining) if isinstance(idx, ENTITY_TYPE) else idx
            for idx in self.indexes
        ]

    def __call__(self, idx):
        position = self.indexes[0]
        if isinstance(position, Integral):
            # a single integer position selects one label
            return self.new_scalar([idx], dtype=idx.dtype)

        shape = tuple(calc_shape(idx.shape, self.indexes))
        index_value = indexing_index_value(idx.index_value, position)
        inputs = [
            idx,
            *(ind for ind in self.indexes if isinstance(ind, ENTITY_TYPE)),
        ]
        return self.new_index(
            inputs,
            shape=shape,
            dtype=idx.dtype,
            index_value=index_value,
            name=idx.name,
        )


class DataFrameIlocSetItem(DataFrameOperator, DataFrameOperatorMixin):
    """
    Positional (``iloc``) assignment of ``value`` into a DataFrame.

    The produced DataFrame carries exactly the input's metadata.
    """

    _op_type_ = opcodes.DATAFRAME_ILOC_SETITEM

    indexes = ListField("indexes", default=None)
    value = AnyField("value", default=None)

    def __init__(self, gpu=None, sparse=False, output_types=None, **kw):
        super().__init__(gpu=gpu, sparse=sparse, _output_types=output_types, **kw)
        # default to a DataFrame output when none was requested
        if not self.output_types:
            self.output_types = [OutputType.dataframe]

    def __call__(self, df):
        # shape, dtypes and both axes are copied unchanged from the input
        meta = dict(
            shape=df.shape,
            dtypes=df.dtypes,
            index_value=df.index_value,
            columns_value=df.columns_value,
        )
        return self.new_dataframe([df], **meta)


class SeriesIlocSetItem(DataFrameOperator, DataFrameOperatorMixin):
    """
    Positional (``iloc``) assignment of ``value`` into a Series.

    The produced Series carries exactly the input's metadata.
    """

    _op_module_ = "series"
    _op_type_ = opcodes.DATAFRAME_ILOC_SETITEM

    indexes = ListField("indexes", default=None)
    value = AnyField("value", default=None)

    def __init__(self, gpu=None, sparse=False, output_types=None, **kw):
        # Accept ``output_types`` like the sibling operators do, instead of
        # hard-coding ``_output_types`` (which made an explicit
        # ``_output_types`` keyword collide with a duplicate-argument error).
        super().__init__(
            gpu=gpu,
            sparse=sparse,
            _output_types=output_types,
            **kw,
        )
        # default to a Series output when none was requested
        if not self.output_types:
            self.output_types = [OutputType.series]

    def __call__(self, series):
        # shape, dtype, index and name are copied unchanged from the input
        return self.new_series(
            [series],
            shape=series.shape,
            dtype=series.dtype,
            index_value=series.index_value,
            name=series.name,
        )


def index_getitem(idx, indexes):
    """Select from the Index ``idx`` by position using ``indexes``."""
    normalized = process_iloc_indexes(idx, indexes)
    getter = IndexIlocGetItem(indexes=normalized)
    return getter(idx)


def index_setitem(_idx, *_):
    """
    Reject item assignment on an Index.

    Raises
    ------
    TypeError
        Always; every argument is ignored.
    """
    message = "Index does not support mutable operations"
    raise TypeError(message)


def iloc(a):
    """Return a positional indexer (supports ``[]`` read and scalar write) bound to ``a``."""
    indexer = DataFrameIloc(a)
    return indexer


def head(a, n=5):
    """
    Return the first `n` rows.

    Rows are selected by position, which makes this handy for a quick look
    at whether an object holds the kind of data you expect.

    A negative `n` keeps every row except the last ``|n|`` ones, the same
    as ``df[:n]``.

    Parameters
    ----------
    n : int, default 5
        Number of rows to select.

    Returns
    -------
    same type as caller
        The first `n` rows of the caller object.

    See Also
    --------
    DataFrame.tail : Returns the last `n` rows.

    Examples
    --------
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame({'animal': ['alligator', 'bee', 'falcon', 'lion',
    ...                    'monkey', 'parrot', 'shark', 'whale', 'zebra']})
    >>> df.execute()
          animal
    0  alligator
    1        bee
    2     falcon
    3       lion
    4     monkey
    5     parrot
    6      shark
    7      whale
    8      zebra

    Viewing the first 5 lines

    >>> df.head().execute()
          animal
    0  alligator
    1        bee
    2     falcon
    3       lion
    4     monkey

    Viewing the first `n` lines (three in this case)

    >>> df.head(3).execute()
          animal
    0  alligator
    1        bee
    2     falcon

    For negative values of `n`

    >>> df.head(-3).execute()
          animal
    0  alligator
    1        bee
    2     falcon
    3       lion
    4     monkey
    5     parrot
    """
    # rows [0, n): a positive n keeps a prefix, a negative n drops a suffix
    rows = slice(0, n)
    return DataFrameIloc(a)[rows]


def tail(a, n=5):
    """
    Return the last `n` rows.

    This function returns last `n` rows from the object based on
    position. It is useful for quickly verifying data, for example,
    after sorting or appending rows.

    For negative values of `n`, this function returns all rows except
    the first `n` rows, equivalent to ``df[n:]``. For ``n == 0`` an
    empty result is returned.

    Parameters
    ----------
    n : int, default 5
        Number of rows to select.

    Returns
    -------
    type of caller
        The last `n` rows of the caller object.

    See Also
    --------
    DataFrame.head : The first `n` rows of the caller object.

    Examples
    --------
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame({'animal': ['alligator', 'bee', 'falcon', 'lion',
    ...                    'monkey', 'parrot', 'shark', 'whale', 'zebra']})
    >>> df.execute()
          animal
    0  alligator
    1        bee
    2     falcon
    3       lion
    4     monkey
    5     parrot
    6      shark
    7      whale
    8      zebra

    Viewing the last 5 lines

    >>> df.tail().execute()
       animal
    4  monkey
    5  parrot
    6   shark
    7   whale
    8   zebra

    Viewing the last `n` lines (three in this case)

    >>> df.tail(3).execute()
      animal
    6  shark
    7  whale
    8  zebra

    For negative values of `n`

    >>> df.tail(-3).execute()
       animal
    3    lion
    4  monkey
    5  parrot
    6   shark
    7   whale
    8   zebra
    """
    if n == 0:
        # ``-0`` equals ``0``, so ``[-n:]`` would select every row here;
        # return an empty selection instead.
        return DataFrameIloc(a)[0:0]
    return DataFrameIloc(a)[-n:]
