# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import numpy as np
import pandas as pd

from ... import opcodes
from ...core import ENTITY_TYPE
from ...serialization.serializables import AnyField, BoolField, Int32Field, StringField
from ...tensor.utils import filter_inputs
from ..core import DATAFRAME_TYPE, SERIES_TYPE
from ..operators import DataFrameOperator, DataFrameOperatorMixin
from ..utils import build_df, build_series, validate_axis


class DataFrameWhere(DataFrameOperator, DataFrameOperatorMixin):
    """
    Operator backing ``where`` and ``mask`` on DataFrames and Series.

    It records the caller (``input``) together with ``cond`` and ``other``.
    Either of these may be a lazy entity (tracked as an extra operator input)
    or a plain value. ``replace_true`` is passed as ``True`` by ``mask`` and
    as ``False`` by ``where``.
    """

    _op_type_ = opcodes.WHERE

    _input = AnyField("input")
    cond = AnyField("cond", default=None)
    other = AnyField("other", default=None)
    axis = Int32Field("axis", default=None)
    level = AnyField("level", default=None)
    errors = StringField("errors", default=None)
    try_cast = BoolField("try_cast", default=None)
    replace_true = BoolField("replace_true", default=None)

    def __init__(self, input=None, **kw):
        super().__init__(_input=input, **kw)

    @property
    def input(self):
        """The DataFrame or Series the condition is applied to."""
        return self._input

    def __call__(self, df_or_series):
        """
        Build the lazy output object for ``df_or_series``.

        The output keeps the caller's shape, index and (for DataFrames)
        columns. Its dtypes are inferred by running pandas ``where`` on mock
        objects built from the caller and ``other``.

        Raises
        ------
        NotImplementedError
            If ``cond`` or ``other`` is a lazy DataFrame/Series whose
            index (or columns) key differs from the caller's, because
            alignment of different indices is not supported.
        """

        def _check_input_index(obj, axis=None):
            axis = axis if axis is not None else self.axis
            if isinstance(obj, DATAFRAME_TYPE) and (
                df_or_series.columns_value.key != obj.columns_value.key
                or df_or_series.index_value.key != obj.index_value.key
            ):
                raise NotImplementedError("Aligning different indices not supported")
            elif (
                isinstance(obj, SERIES_TYPE)
                and df_or_series.axes[axis].index_value.key != obj.index_value.key
            ):
                raise NotImplementedError("Aligning different indices not supported")

        # ``cond`` is checked against the index; ``other`` against ``self.axis``.
        _check_input_index(self.cond, axis=0)
        _check_input_index(self.other)

        if isinstance(df_or_series, DATAFRAME_TYPE):
            mock_obj = build_df(df_or_series)
        else:
            mock_obj = build_series(df_or_series)

        if isinstance(self.other, (pd.DataFrame, DATAFRAME_TYPE)):
            mock_other = build_df(self.other)
        elif isinstance(self.other, (pd.Series, SERIES_TYPE)):
            mock_other = build_series(self.other)
        else:
            mock_other = self.other

        where_kw = {"other": mock_other, "axis": self.axis, "level": self.level}
        # ``errors`` and ``try_cast`` were deprecated in pandas 1.x and removed
        # from ``where``/``mask`` in pandas 2.0. Forward them only when the
        # installed pandas accepts them; otherwise dtype inference would fail
        # with a TypeError.
        where_params = inspect.signature(mock_obj.where).parameters
        if "errors" in where_params:
            where_kw["errors"] = self.errors
        if "try_cast" in where_params:
            where_kw["try_cast"] = self.try_cast

        # With an all-False condition pandas ``where`` takes every entry from
        # ``other``, so the mock result reflects the combined output dtype.
        result_df = mock_obj.where(
            np.zeros(mock_obj.shape).astype(bool),
            **where_kw,
        )

        inputs = filter_inputs([df_or_series, self.cond, self.other])
        if isinstance(df_or_series, DATAFRAME_TYPE):
            return self.new_dataframe(
                inputs,
                shape=df_or_series.shape,
                dtypes=result_df.dtypes,
                index_value=df_or_series.index_value,
                columns_value=df_or_series.columns_value,
            )
        else:
            return self.new_series(
                inputs,
                shape=df_or_series.shape,
                name=df_or_series.name,
                dtype=result_df.dtype,
                index_value=df_or_series.index_value,
            )

    def _set_inputs(self, inputs):
        """
        Rebind ``input``, ``cond`` and ``other`` to the given inputs.

        Inputs are consumed in the order ``__call__`` collected them: the
        caller first, then ``cond`` and ``other`` when they are entities.
        """
        super()._set_inputs(inputs)
        inputs_iter = iter(self._inputs)
        self._input = next(inputs_iter)
        if isinstance(self.cond, ENTITY_TYPE):
            self.cond = next(inputs_iter)
        if isinstance(self.other, ENTITY_TYPE):
            self.other = next(inputs_iter)


_doc_template = """
Replace values where the condition is {replace_true}.

Parameters
----------
cond : bool Series/DataFrame, array-like, or callable
    Where `cond` is {replace_true}, replace with the corresponding value
    from `other`; elsewhere, keep the original value.
    If `cond` is callable, it is computed on the Series/DataFrame and
    should return boolean Series/DataFrame or array. The callable must
    not change input Series/DataFrame (though pandas doesn't check it).
other : scalar, Series/DataFrame, or callable
    Entries where `cond` is {replace_true} are replaced with the
    corresponding value from `other`.
    If other is callable, it is computed on the Series/DataFrame and
    should return scalar or Series/DataFrame. The callable must not
    change input Series/DataFrame (though pandas doesn't check it).
inplace : bool, default False
    Whether to perform the operation in place on the data.
axis : int, default None
    Alignment axis if needed. Must be given when the caller is a
    DataFrame and `other` is one-dimensional.
level : int, default None
    Alignment level if needed.
errors : str, {{'raise', 'ignore'}}, default 'raise'
    Note that currently this parameter won't affect
    the results and will always coerce to a suitable dtype.

    - 'raise' : allow exceptions to be raised.
    - 'ignore' : suppress exceptions. On error return original object.

try_cast : bool, default False
    Try to cast the result back to the input type (if possible).

Returns
-------
Same type as caller

See Also
--------
:func:`DataFrame.{opposite}` : Return an object of same shape as
    self.

Notes
-----
This method is an application of the if-then idiom. For each
element in the calling DataFrame, if ``cond`` is ``{replace_true}`` the
corresponding element from the DataFrame ``other`` is used; otherwise
the original element is kept.

The signature for :func:`DataFrame.where` differs from
:func:`numpy.where`. Roughly ``df1.where(m, df2)`` is equivalent to
``np.where(m, df1, df2)``.

For further details and examples see the ``where`` and ``mask``
documentation in :ref:`indexing <indexing.where_mask>`.

Examples
--------
>>> import maxframe.tensor as mt
>>> import maxframe.dataframe as md
>>> s = md.Series(range(5))
>>> s.where(s > 0).execute()
0    NaN
1    1.0
2    2.0
3    3.0
4    4.0
dtype: float64

>>> s.mask(s > 0).execute()
0    0.0
1    NaN
2    NaN
3    NaN
4    NaN
dtype: float64

>>> s.where(s > 1, 10).execute()
0    10
1    10
2    2
3    3
4    4
dtype: int64

>>> df = md.DataFrame(mt.arange(10).reshape(-1, 2), columns=['A', 'B'])
>>> df.execute()
   A  B
0  0  1
1  2  3
2  4  5
3  6  7
4  8  9
>>> m = df % 3 == 0
>>> df.where(m, -df).execute()
   A  B
0  0 -1
1 -2  3
2 -4 -5
3  6 -7
4 -8  9
>>> (df.where(m, -df) == mt.where(m, df, -df)).execute()
      A     B
0  True  True
1  True  True
2  True  True
3  True  True
4  True  True
>>> (df.where(m, -df) == df.mask(~m, -df)).execute()
      A     B
0  True  True
1  True  True
2  True  True
3  True  True
4  True  True
"""


def _where(
    df_or_series,
    cond,
    other=np.nan,
    inplace=False,
    axis=None,
    level=None,
    errors="raise",
    try_cast=False,
    replace_true=False,
):
    """
    Shared implementation behind :func:`where` and :func:`mask`.

    Builds a ``DataFrameWhere`` operator and applies it to ``df_or_series``.
    ``replace_true`` is forwarded to the operator; ``where`` passes ``False``
    and ``mask`` passes ``True``.

    When ``inplace`` is truthy, the caller's ``data`` is replaced with the
    result's data and ``None`` is returned. Otherwise the new object is
    returned.

    Raises
    ------
    ValueError
        If the caller is two-dimensional, ``other`` is one-dimensional and
        no ``axis`` was given.
    """
    # A 1-d ``other`` on a 2-d caller is ambiguous without an explicit axis.
    one_dim_other_on_frame = (
        df_or_series.ndim == 2 and getattr(other, "ndim", 2) == 1
    )
    if one_dim_other_on_frame and axis is None:
        raise ValueError("Must specify axis=0 or 1")

    op = DataFrameWhere(
        cond=cond,
        other=other,
        axis=validate_axis(axis or 0, df_or_series),
        level=level,
        errors=errors,
        try_cast=try_cast,
        replace_true=replace_true,
    )
    result = op(df_or_series)
    if not inplace:
        return result
    df_or_series.data = result.data


def where(
    df_or_series,
    cond,
    other=np.nan,
    inplace=False,
    axis=None,
    level=None,
    errors="raise",
    try_cast=False,
):
    # ``where`` semantics: entries are replaced where ``cond`` is False.
    # The public docstring is attached at module level from ``_doc_template``.
    shared_kw = dict(
        other=other,
        inplace=inplace,
        axis=axis,
        level=level,
        errors=errors,
        try_cast=try_cast,
    )
    return _where(df_or_series, cond, replace_true=False, **shared_kw)


def mask(
    df_or_series,
    cond,
    other=np.nan,
    inplace=False,
    axis=None,
    level=None,
    errors="raise",
    try_cast=False,
):
    # ``mask`` semantics: entries are replaced where ``cond`` is True.
    # The public docstring is attached at module level from ``_doc_template``.
    shared_kw = dict(
        other=other,
        inplace=inplace,
        axis=axis,
        level=level,
        errors=errors,
        try_cast=try_cast,
    )
    return _where(df_or_series, cond, replace_true=True, **shared_kw)


# Render the shared template once per public entry point; each docstring
# points at its complementary method in the "See Also" section.
where.__doc__ = _doc_template.format(opposite="mask", replace_true=False)
mask.__doc__ = _doc_template.format(opposite="where", replace_true=True)
