# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections.abc import Iterable
from contextlib import contextmanager

import numpy as np

from ...serialization.serializables import FieldTypes, Int32Field, TupleField
from ..core import TENSOR_TYPE
from ..datasource import tensor as astensor
from ..misc import broadcast_to
from ..operators import TensorMapReduceOperator, TensorOperator, TensorOperatorMixin
from ..utils import broadcast_shape


class RandomState:
    """Thin wrapper around a :class:`numpy.random.RandomState` instance."""

    def __init__(self, seed=None):
        self._random_state = np.random.RandomState(seed=seed)

    def seed(self, seed=None):
        """
        Seed the generator.

        This method is called when `RandomState` is initialized. It can be
        called again to re-seed the generator. For details, see `RandomState`.

        Parameters
        ----------
        seed : int or 1-d array_like, optional
            Seed for `RandomState`.
            Must be convertible to 32 bit unsigned integers.

        See Also
        --------
        RandomState
        """
        self._random_state.seed(seed=seed)

    def to_numpy(self):
        """Return the wrapped numpy RandomState itself (not a copy)."""
        return self._random_state

    @classmethod
    def from_numpy(cls, np_random_state):
        """Wrap an existing numpy RandomState without copying it.

        ``cls`` is used (rather than the base class) so that subclasses
        round-trip to their own type.
        """
        state = cls()
        state._random_state = np_random_state
        return state

    @classmethod
    def _handle_size(cls, size):
        """Normalize a ``size`` argument.

        ``None`` is returned unchanged, an iterable becomes a tuple of ints,
        and a non-iterable value is wrapped into a 1-tuple as-is.
        """
        if size is None:
            return size
        try:
            return tuple(int(s) for s in size)
        except TypeError:
            # iterating a scalar raises TypeError: treat it as a 1-d size
            return (size,)


# Module-level default state; ``seed=None`` lets numpy choose a fresh seed.
_random_state = RandomState(seed=None)


def handle_array(arg):
    """Return a representative element of ``arg``.

    The fallback for tensors without concrete data is an empty array of the
    tensor's dtype, which suggests callers only rely on the dtype of the
    result (presumably for result-dtype inference; confirm against callers).

    - Non-iterable values (Python/numpy scalars) are returned unchanged.
    - Other non-tensor iterables are converted with ``np.asarray``. Their
      first element is returned, a 0-d array yields its scalar, and an empty
      array yields an empty array of the same dtype.
    - Tensors whose op carries ``data`` return the first element of it.
    - Any other tensor returns an empty array of the tensor's dtype.
    """
    if not isinstance(arg, TENSOR_TYPE):
        if not isinstance(arg, Iterable):
            return arg

        arg = np.asarray(arg)
        if arg.ndim == 0:
            # 0-d arrays count as Iterable but cannot be indexed with (0,)
            return arg[()]
        if arg.size == 0:
            # no element to pick; an empty array still carries the dtype
            return np.empty((0,), dtype=arg.dtype)
        return arg[(0,) * arg.ndim]
    elif hasattr(arg, "op") and hasattr(arg.op, "data"):
        return arg.op.data[(0,) * max(1, arg.ndim)]

    return np.empty((0,), dtype=arg.dtype)


class TensorRandomOperatorMixin(TensorOperatorMixin):
    """Mixin for random tensor operators whose parameters may be tensors.

    Fields listed in ``_input_fields_`` may hold operator inputs. Before the
    tileable is created they are converted to tensors and broadcast to the
    output shape. Afterwards they are rebound to the operator's ``_inputs``.
    Fields listed in ``_into_one_chunk_fields_`` are excluded from the
    list/tensor conversion and from broadcasting.
    """

    __slots__ = ()

    def _calc_shape(self, shapes):
        # Broadcast the given shapes together with ``size`` when it is set.
        shapes = list(shapes)
        if getattr(self, "size", None) is not None:
            shapes.append(getattr(self, "size"))
        return broadcast_shape(*shapes)

    @classmethod
    def _handle_arg(cls, arg, chunk_size):
        # Lists and numpy arrays become tensors; anything else passes through.
        if isinstance(arg, (list, np.ndarray)):
            arg = astensor(arg, chunk_size=chunk_size)

        return arg

    @contextmanager
    def _get_inputs_shape_by_given_fields(
        self, inputs, shape, raw_chunk_size=None, tensor=True
    ):
        """Yield ``(tensor_inputs, shape)`` for building a new tileable.

        Parameters
        ----------
        inputs : sequence
            On first creation (first input field still ``None``), one value
            per entry of ``_input_fields_``, in order. Otherwise, only the
            values for the fields that currently hold tensors, in field order.
        shape : tuple or None
            Output shape. If None and ``tensor`` is true, it is computed by
            broadcasting the tensor field shapes with ``size``.
        raw_chunk_size :
            Forwarded as ``chunk_size`` when converting lists/arrays.
        tensor : bool
            Whether to convert values to tensors and broadcast them.

        Yields the tensor-valued field objects in field order, plus the shape.
        After the ``with`` body completes, each tensor field is rebound to
        the matching entry of ``self._inputs``. That code is not in a
        ``finally`` block, so it is skipped if the body raises.
        """
        fields = getattr(self, "_input_fields_", [])
        to_one_chunk_fields = set(getattr(self, "_into_one_chunk_fields_", list()))

        field_to_obj = dict()
        to_broadcast_shapes = []
        if fields:
            if getattr(self, fields[0], None) is None:
                # create from beginning
                for field, val in zip(fields, inputs):
                    if field not in to_one_chunk_fields:
                        if isinstance(val, list):
                            val = np.asarray(val)
                        if tensor:
                            val = self._handle_arg(val, raw_chunk_size)
                    if isinstance(val, TENSOR_TYPE):
                        field_to_obj[field] = val
                        if field not in to_one_chunk_fields:
                            to_broadcast_shapes.append(val.shape)
                    setattr(self, field, val)
            else:
                # fields already populated: match ``inputs`` positionally to
                # the fields that hold tensors. NOTE(review): assumes
                # len(inputs) >= number of tensor fields; otherwise next()
                # raises inside this generator.
                inputs_iter = iter(inputs)
                for field in fields:
                    if isinstance(getattr(self, field), TENSOR_TYPE):
                        field_to_obj[field] = next(inputs_iter)

        if tensor:
            if shape is None:
                shape = self._calc_shape(to_broadcast_shapes)

            # only existing keys are reassigned, so mutating while iterating
            # over items() is safe here
            for field, inp in field_to_obj.items():
                if field not in to_one_chunk_fields:
                    field_to_obj[field] = broadcast_to(inp, shape)

        yield [field_to_obj[f] for f in fields if f in field_to_obj], shape

        # rebind tensor fields to the operator's final inputs (same order)
        inputs_iter = iter(getattr(self, "_inputs"))
        for field in fields:
            if field in field_to_obj:
                setattr(self, field, next(inputs_iter))

    @classmethod
    def _get_shape(cls, kws, kw):
        # An explicit ``shape`` keyword wins; otherwise use the first kws
        # entry's shape; otherwise implicitly return None.
        if kw.get("shape") is not None:
            return kw.get("shape")
        elif kws is not None and len(kws) > 0:
            return kws[0].get("shape")

    def _new_tileables(self, inputs, kws=None, **kw):
        # Convert/broadcast field inputs, then delegate to the next
        # ``_new_tileables`` in the MRO. The return value is computed before
        # the context manager's exit rebinds the fields.
        raw_chunk_size = kw.get("chunk_size", None)
        shape = self._get_shape(kws, kw)
        with self._get_inputs_shape_by_given_fields(
            inputs, shape, raw_chunk_size, True
        ) as (inputs, shape):
            kw["shape"] = shape
            return super()._new_tileables(inputs, kws=kws, **kw)


def _on_serialize_random_state(rs):
    return rs.get_state() if rs is not None else None


def _on_deserialize_random_state(tup):
    if tup is None:
        return None

    rs = np.random.RandomState()
    rs.set_state(tup)
    return rs


def RandomStateField(name, **kwargs):
    """Build a ``TupleField`` that stores a numpy RandomState as its state tuple.

    Any ``on_serialize`` / ``on_deserialize`` supplied by the caller is
    overridden by the random-state converters.
    """
    field_kwargs = dict(kwargs)
    field_kwargs["on_serialize"] = _on_serialize_random_state
    field_kwargs["on_deserialize"] = _on_deserialize_random_state
    return TupleField(name, **field_kwargs)


class TensorSeedOperatorMixin(object):
    """Mixin exposing ``seed`` and ``args`` for random tensor operators."""

    @property
    def seed(self):
        # Fallback only: subclasses that declare a ``seed`` field shadow this
        # property. Looking up "seed" here would re-enter this very property
        # and recurse until RecursionError, so read a private ``_seed``
        # attribute instead and default to None.
        return getattr(self, "_seed", None)

    @property
    def args(self):
        # An explicit ``_fields_`` wins; otherwise list the fields that are
        # not declared on the base ``TensorRandomOperator``.
        if hasattr(self, "_fields_"):
            return self._fields_
        else:
            return [
                field
                for field in self._FIELDS
                if field not in TensorRandomOperator._FIELDS
            ]


class TensorRandomOperator(TensorSeedOperatorMixin, TensorOperator):
    """Base random tensor operator carrying an int32 ``seed`` field."""

    seed = Int32Field("seed")

    def __init__(self, dtype=None, **kw):
        # accept any dtype-like and normalize it to a numpy dtype
        if dtype is not None:
            dtype = np.dtype(dtype)
        # a ``state`` keyword is forwarded under the name ``_state``
        if "state" in kw:
            state = kw.pop("state")
            kw["_state"] = state
        super().__init__(dtype=dtype, **kw)


class TensorRandomMapReduceOperator(TensorSeedOperatorMixin, TensorMapReduceOperator):
    """Map-reduce random tensor operator carrying an int32 ``seed`` field."""

    seed = Int32Field("seed")

    def __init__(self, dtype=None, **kw):
        # accept any dtype-like and normalize it to a numpy dtype
        if dtype is not None:
            dtype = np.dtype(dtype)
        # a ``state`` keyword is forwarded under the name ``_state``
        if "state" in kw:
            state = kw.pop("state")
            kw["_state"] = state
        super().__init__(dtype=dtype, **kw)


class TensorDistribution(TensorRandomOperator):
    """Random operator with a ``size`` tuple field (int64 elements).

    ``size`` is presumably the requested output shape; confirm against the
    concrete distribution operators.
    """

    size = TupleField("size", FieldTypes.int64)


class TensorSimpleRandomData(TensorRandomOperator):
    """Random operator parameterized by a ``size`` tuple (int64 elements)."""

    size = TupleField("size", FieldTypes.int64)

    def __init__(self, size=None, **kw):
        # Promote a bare integer (Python int, int subclass or numpy integer,
        # but not bool) to a 1-tuple of plain int, so the tuple-typed field
        # never receives a scalar.
        if isinstance(size, (int, np.integer)) and not isinstance(size, bool):
            size = (int(size),)
        super().__init__(size=size, **kw)
