# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

import numpy as np
import pandas as pd

from ... import opcodes
from ...core import OutputType
from ...serialization.serializables import (
    BoolField,
    DictField,
    FunctionField,
    TupleField,
)
from ..core import DataFrame
from ..operators import DataFrameOperator, DataFrameOperatorMixin
from ..utils import (
    copy_func_scheduling_hints,
    gen_unknown_index_value,
    make_dtypes,
    parse_index,
)


class DataFrameFlatMapOperator(DataFrameOperator, DataFrameOperatorMixin):
    """
    Operator describing a ``flatmap`` over a DataFrame or a Series.

    The operator stores the user function together with its extra arguments
    and, when called, builds the output tileable: a DataFrame or Series whose
    row count is unknown (``np.nan``) at graph-construction time.
    """

    _op_type_ = opcodes.FLATMAP

    # user function handed in by ``df_flatmap`` / ``series_flatmap``
    func = FunctionField("func")
    # value of the ``raw`` argument; ``series_flatmap`` always stores False
    raw = BoolField("raw", default=False)
    # extra positional arguments for ``func``
    args = TupleField("args", default=())
    # extra keyword arguments for ``func``
    # NOTE(review): ``{}`` here is a single dict object created at class
    # definition time -- confirm the field machinery never hands it out for
    # in-place mutation.
    kwargs = DictField("kwargs", default={})

    def __init__(self, output_types=None, **kw):
        """Create the operator and copy scheduling hints from ``func`` if it is set."""
        super().__init__(_output_types=output_types, **kw)
        if not hasattr(self, "func"):
            # nothing to copy hints from when no function is attached
            return
        copy_func_scheduling_hints(self.func, self)

    def _call_dataframe(self, df: DataFrame, dtypes: pd.Series):
        """
        Build the output DataFrame for a DataFrame input.

        Parameters
        ----------
        df : DataFrame
            Input tileable.
        dtypes : Series, dict or list
            Dtypes of the result, normalized through ``make_dtypes``.

        Returns
        -------
        DataFrame
            Tileable with ``len(dtypes)`` columns and an unknown number of rows.
        """
        out_dtypes = make_dtypes(dtypes)
        out_index = gen_unknown_index_value(
            df.index_value,
            (df.key, df.index_value.key, self.func),
            normalize_range_index=True,
        )
        column_count = len(out_dtypes)
        out_columns = parse_index(out_dtypes.index, store_data=True)
        return self.new_dataframe(
            [df],
            shape=(np.nan, column_count),
            index_value=out_index,
            columns_value=out_columns,
            dtypes=out_dtypes,
        )

    def _call_series_or_index(self, series, dtypes=None):
        """
        Build the output tileable for a non-DataFrame input.

        Parameters
        ----------
        series
            Input tileable.
        dtypes
            A ``(name, dtype)`` pair when the operator's first output type is
            a series, otherwise the dtypes of the resulting DataFrame.

        Returns
        -------
        Series or DataFrame
            Series when ``self.output_types[0]`` is ``OutputType.series``,
            DataFrame otherwise; both with an unknown number of rows.
        """
        out_index = gen_unknown_index_value(
            series.index_value,
            (series.key, series.index_value.key, self.func),
            normalize_range_index=True,
        )

        produces_series = self.output_types[0] == OutputType.series
        if not produces_series:
            frame_dtypes = make_dtypes(dtypes)
            frame_columns = parse_index(frame_dtypes.index, store_data=True)
            return self.new_dataframe(
                [series],
                shape=(np.nan, len(frame_dtypes)),
                index_value=out_index,
                columns_value=frame_columns,
                dtypes=frame_dtypes,
            )

        # series output: ``dtypes`` carries the result name and dtype as a pair
        name, dtype = dtypes
        return self.new_series(
            [series],
            dtype=dtype,
            shape=(np.nan,),
            index_value=out_index,
            name=name,
        )

    def __call__(
        self,
        df_or_series,
        dtypes=None,
        output_type=None,
    ):
        """
        Dispatch on the input kind and return the output tileable.

        ``output_type`` is accepted but not read here; the output kind comes
        from ``self.output_types``.
        """
        input_is_frame = df_or_series.op.output_types[0] == OutputType.dataframe
        builder = self._call_dataframe if input_is_frame else self._call_series_or_index
        return builder(df_or_series, dtypes=dtypes)


def df_flatmap(dataframe, func: Callable, dtypes=None, raw=False, args=(), **kwargs):
    """
    Apply the given function to each row and then flatten results. Use this method if your transformation returns
    multiple rows for each input row.

    This function applies a transformation to each row of the DataFrame, where the transformation can return zero
    or multiple values, effectively flattening Python generators, list-like collections, and DataFrames.

    Parameters
    ----------
    dataframe : DataFrame
        Input DataFrame whose rows are passed to `func`.

    func : Callable
        Function to apply to each row of the DataFrame. It should accept a Series (or an array if `raw=True`)
        representing a row and return a list or iterable of values.

    dtypes : Series, dict or list
        Specify dtypes of returned DataFrame. Required and must not be empty.

    raw : bool, default False
        Determines if the row is passed as a Series or as a numpy array:

        * ``False`` : passes each row as a Series to the function.
        * ``True`` : the passed function will receive numpy array objects instead.

    args : tuple
        Positional arguments to pass to `func`.

    **kwargs
        Additional keyword arguments to pass as keywords arguments to `func`.

    Returns
    -------
    DataFrame
        Return DataFrame with specified `dtypes`.

    Raises
    ------
    TypeError
        If `dtypes` is missing or empty, or if `func` is not callable.

    Notes
    -----
    The ``func`` must return an iterable of values for each input row. The index of the resulting DataFrame will be
    repeated based on the number of output rows generated by `func`.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
    >>> df.execute()
       A  B
    0  1  4
    1  2  5
    2  3  6

    Define a function that takes a row and returns a list of two values:

    >>> def generate_values_array(row):
    ...     return [row['A'] * 2, row['B'] * 3]

    Define a function that takes a row and returns two rows with two columns:

    >>> def generate_values_in_generator(row):
    ...     yield [row[0] * 2, row[1] * 4]
    ...     yield [row[0] * 3, row[1] * 5]

    Which is equivalent to the following function returning a DataFrame:

    >>> def generate_values_in_dataframe(row):
    ...     return pd.DataFrame([[row[0] * 2, row[1] * 4], [row[0] * 3, row[1] * 5]])

    Specify `dtypes` with a function which returns a list:

    >>> df.mf.flatmap(generate_values_array, dtypes=pd.Series({'A': 'int'})).execute()
            A
        0   2
        0  12
        1   4
        1  15
        2   6
        2  18

    Specify raw=True to pass input row as array:

    >>> df.mf.flatmap(generate_values_in_generator, dtypes={"A": "int", "B": "int"}, raw=True).execute()
           A   B
        0  2  16
        0  3  20
        1  4  20
        1  6  25
        2  6  24
        2  9  30
    """
    # Use ``len`` instead of truthiness: ``dtypes`` may be a pandas Series,
    # whose truth value is ambiguous and raises ValueError.
    if dtypes is None or len(dtypes) == 0:
        raise TypeError(
            "Cannot determine `dtypes` by calculating with enumerate data, "
            "please specify it as arguments"
        )

    if not callable(func):
        raise TypeError("function must be a callable object")

    op = DataFrameFlatMapOperator(
        func=func,
        raw=raw,
        output_types=[OutputType.dataframe],
        args=args,
        kwargs=kwargs,
    )
    return op(dataframe, dtypes=dtypes)


def series_flatmap(
    series, func: Callable, dtypes=None, dtype=None, name=None, args=(), **kwargs
):
    """
    Apply the given function to each element and then flatten results. Use this method if your transformation
    returns multiple rows for each input element.

    This function applies a transformation to each element of the Series, where the transformation can return zero
    or multiple values, effectively flattening Python generators, list-like collections and DataFrames.

    Parameters
    ----------
    series : Series
        Input Series whose elements are passed to ``func``.

    func : Callable
        Function to apply to each element of the Series. It should accept a scalar value
        and return a list or iterable of values.

    dtypes : Series, dict or list, default None
        Specify dtypes of returned DataFrame. Must not be empty. Can't work with dtype.

    dtype : numpy.dtype, default None
        Specify dtype of returned Series. Can't work with dtypes.

    name : str, default None
        Specify name of the returned Series.

    args : tuple
        Positional arguments to pass to ``func``.

    **kwargs
        Additional keyword arguments to pass as keywords arguments to ``func``.

    Returns
    -------
    DataFrame or Series
        A DataFrame when ``dtypes`` is specified, a Series when ``dtype`` is specified.

    Raises
    ------
    ValueError
        If both ``dtypes`` and ``dtype`` are specified.
    TypeError
        If neither ``dtypes`` nor ``dtype`` is specified, if ``dtypes`` is empty,
        or if ``func`` is not callable.

    Notes
    -----
    The ``func`` must return an iterable of values for each input element. If ``dtypes`` is specified,
    `flatmap` will return a DataFrame, if ``dtype`` (and optionally ``name``) is specified, a Series
    will be returned.

    The index of the resulting DataFrame/Series will be repeated based on the number of output rows generated
    by ``func``.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
    >>> df.execute()
       A  B
    0  1  4
    1  2  5
    2  3  6

    Define a function that takes a number and returns a list of two numbers:

    >>> def generate_values_array(x):
    ...     return [x * 2, x * 3]

    Specify ``dtype`` with a function which returns list to return more elements as a Series:

    >>> df['A'].mf.flatmap(generate_values_array, dtype="int", name="C").execute()
        0    2
        0    3
        1    4
        1    6
        2    6
        2    9
        Name: C, dtype: int64

    Specify ``dtypes`` to return multi columns as a DataFrame:

    >>> def generate_values_in_generator(x):
    ...     yield pd.Series([x * 2, x * 4])
    ...     yield pd.Series([x * 3, x * 5])

    >>> df['A'].mf.flatmap(generate_values_in_generator, dtypes={"A": "int", "B": "int"}).execute()
           A   B
        0  2   4
        0  3   5
        1  4   8
        1  6  10
        2  6  12
        2  9  15
    """

    if dtypes is not None and dtype is not None:
        raise ValueError("Both dtypes and dtype cannot be specified at the same time.")

    returns_series = dtype is not None
    if returns_series:
        # the operator expects the Series result spec as a (name, dtype) pair
        dtypes = (name, dtype)
    elif dtypes is None or len(dtypes) == 0:
        # Empty ``dtypes`` is rejected the same way ``df_flatmap`` does.
        # ``len`` is used instead of truthiness because ``dtypes`` may be a
        # pandas Series, whose truth value is ambiguous.
        raise TypeError(
            "Cannot determine `dtypes` or `dtype` by calculating with enumerate data, "
            "please specify it as arguments"
        )

    if not callable(func):
        raise TypeError("function must be a callable object")

    output_type = OutputType.series if returns_series else OutputType.dataframe

    op = DataFrameFlatMapOperator(
        func=func, raw=False, output_types=[output_type], args=args, kwargs=kwargs
    )
    return op(series, dtypes=dtypes)
