# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import weakref
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Type

import msgpack

from ...errors import MaxFrameDeprecationError
from ...lib.mmh3 import hash
from ...utils import no_default
from ..core import Placeholder, Serializer, buffered, load_type
from .field import Field
from .field_type import DictType, ListType, PrimitiveFieldType, TupleType

try:
    from ..deserializer import get_legacy_module_name
except ImportError:
    get_legacy_module_name = lambda x: x

# module-level logger, used for the deprecation messages emitted below
logger = logging.getLogger(__name__)
# key under which a serialization context keeps the set of deprecation
# messages that were already logged (see SerializableSerializer._log_legacy)
_deprecate_log_key = "_SER_DEPRECATE_LOGGED"


def _is_field_primitive_compound(field: Field):
    """
    Tell whether the values of a field are primitives or plain containers
    of primitives.

    A field qualifies when it defines neither ``on_serialize`` nor
    ``on_deserialize`` and its field type is one of

    * a ``PrimitiveFieldType``,
    * a ``ListType`` / ``TupleType`` whose element types all qualify
      (``Ellipsis`` entries are accepted),
    * a ``DictType`` whose key and value types are both a
      ``PrimitiveFieldType`` or ``Ellipsis``.
    """
    # fields with custom hooks are never treated as primitive
    if not (field.on_serialize is None and field.on_deserialize is None):
        return False

    def _qualifies(tp):
        if isinstance(tp, PrimitiveFieldType):
            return True
        # sequences are checked recursively, element by element
        if isinstance(tp, (ListType, TupleType)) and all(
            elem is Ellipsis or _qualifies(elem) for elem in tp._field_types
        ):
            return True
        # dicts are only checked one level deep
        return isinstance(tp, DictType) and all(
            part is Ellipsis or isinstance(part, PrimitiveFieldType)
            for part in (tp.key_type, tp.value_type)
        )

    return _qualifies(field.field_type)


class SerializableMeta(type):
    """
    Metaclass collecting the ``Field`` attributes of a class and its bases.

    For every class being created it

    * merges the fields of all bases with the fields declared in the class
      body and records which class (by name hash) declares each field,
    * splits the fields into primitive and non-primitive ones and counts
      them per declaring class,
    * turns every newly declared field into a slot and binds the field to
      the slot's member descriptor,
    * stores the results as class attributes (``_FIELDS``, ``_FIELD_ORDER``,
      ``_PRIMITIVE_FIELDS``, ...).
    """

    def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
        """
        Build the class and attach field metadata to it.

        Parameters
        ----------
        name
            name of the class being created
        bases
            base classes of the class being created
        properties
            namespace of the class body; ``Field`` values are taken out of
            it and ``__slots__`` is replaced with the computed slots

        Returns
        -------
        the newly created class
        """
        # Hashes identifying the class. Note that ``hash`` is the mmh3
        # function imported at the top of this module, not the builtin.
        # The legacy hash is built from the legacy module name and the plain
        # class name, the current one from the module and the qualified name.
        legacy_name_hash = hash(
            f"{get_legacy_module_name(properties.get('__module__'))}.{name}"
        )
        name_hash = hash(
            f"{properties.get('__module__')}.{properties.get('__qualname__')}"
        )
        # all the fields including those inherited from bases
        all_fields = dict()
        # mapping field names to name hashes of the classes declaring them
        field_to_cls_hash = dict()
        # mapping legacy name hash to name hashes
        legacy_to_new_name_hash = {legacy_name_hash: name_hash}

        # inherit field metadata from bases created by this metaclass
        for base in bases:
            if not hasattr(base, "_FIELDS"):
                continue
            all_fields.update(base._FIELDS)
            field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
            legacy_to_new_name_hash.update(base._LEGACY_TO_NEW_NAME_HASH)

        properties_without_fields = {}
        # names of fields that need a new slot in the class being created
        properties_field_slot_names = []
        for k, v in properties.items():
            if not isinstance(v, Field):
                properties_without_fields[k] = v
                continue

            field = all_fields.get(k)
            # record the field for the class being created
            field_to_cls_hash[k] = name_hash
            if field is None:
                properties_field_slot_names.append(k)
            else:
                # the field overrides an inherited one: reuse the accessors
                # of the existing field instead of creating another slot
                v.name = field.name
                v.get = field.get
                v.set = field.set
                v.__delete__ = field.__delete__
            all_fields[k] = v

        # Make field order deterministic to serialize it as list instead of dict.
        field_order = list(all_fields)
        primitive_fields = []
        primitive_field_names = set()
        non_primitive_fields = []
        for field_name, v in all_fields.items():
            if _is_field_primitive_compound(v):
                primitive_fields.append(v)
                primitive_field_names.add(field_name)
            else:
                non_primitive_fields.append(v)

        # count number of fields for every declaring class, separately for
        # primitive and non-primitive fields; keys appear in field order
        cls_to_primitive_field_count = OrderedDict()
        cls_to_non_primitive_field_count = OrderedDict()
        for field_name in field_order:
            cls_hash = field_to_cls_hash[field_name]
            if field_name in primitive_field_names:
                cls_to_primitive_field_count[cls_hash] = (
                    cls_to_primitive_field_count.get(cls_hash, 0) + 1
                )
            else:
                cls_to_non_primitive_field_count[cls_hash] = (
                    cls_to_non_primitive_field_count.get(cls_hash, 0) + 1
                )

        # slots declared in the class body plus one slot per new field
        slots = set(properties.pop("__slots__", set()))
        slots.update(properties_field_slot_names)

        # Field objects must not stay in the namespace, otherwise they would
        # collide with the slots of the same names
        properties = properties_without_fields

        # TODO: remove this property once all versions below v1.0.0rc1 are
        #  eliminated
        properties["_LEGACY_NAME_HASH"] = legacy_name_hash
        properties["_NAME_HASH"] = name_hash
        properties["_LEGACY_TO_NEW_NAME_HASH"] = legacy_to_new_name_hash

        properties["_FIELDS"] = all_fields
        properties["_FIELD_ORDER"] = field_order
        properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
        properties["_PRIMITIVE_FIELDS"] = primitive_fields
        properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
            cls_to_primitive_field_count
        )
        properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
        properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
            cls_to_non_primitive_field_count
        )
        properties["__slots__"] = tuple(slots)

        clz = type.__new__(mcs, name, bases, properties)
        # Bind slot member_descriptor with field.
        # (``name`` is reused as the loop variable from here on, the class
        # name is not needed any more)
        for name in properties_field_slot_names:
            member_descriptor = getattr(clz, name)
            field = all_fields[name]
            field.name = member_descriptor.__name__
            field.get = member_descriptor.__get__
            field.set = member_descriptor.__set__
            field.__delete__ = member_descriptor.__delete__
            # the field takes the place of the descriptor on the class
            setattr(clz, name, field)

        return clz


class Serializable(metaclass=SerializableMeta):
    """
    Base class of objects whose state is declared with ``Field`` attributes.

    ``SerializableMeta`` gathers the fields of the class and its bases and
    fills in the upper-case class attributes annotated below.
    """

    __slots__ = ("__weakref__",)

    # when True, SerializableSerializer keeps the msgpack-encoded primitive
    # field values of every instance in a module-level cache
    _cache_primitive_serial = False
    # when True, ``__init__`` silently skips values given for unknown fields
    _ignore_non_existing_keys = False

    # hashes identifying the class by its legacy and current names
    _LEGACY_NAME_HASH: int
    _NAME_HASH: int
    _LEGACY_TO_NEW_NAME_HASH: Dict[int, int]

    # field definitions by name, including inherited ones
    _FIELDS: Dict[str, Field]
    # field names in definition order
    _FIELD_ORDER: List[str]
    # field name -> name hash of the class declaring the field
    _FIELD_TO_NAME_HASH: Dict[str, int]
    # field definitions (not names) split by kind, with the number of
    # fields of each kind per declaring class
    _PRIMITIVE_FIELDS: List[Field]
    _CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
    _NON_PRIMITIVE_FIELDS: List[Field]
    _CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]

    def __init__(self, *args, **kwargs):
        """
        Set field values from positional and keyword arguments.

        Positional arguments are matched against ``_FIELD_ORDER``; keyword
        arguments take precedence over positional ones with the same name.

        Raises
        ------
        KeyError
            when a name is not a field of the class and
            ``_ignore_non_existing_keys`` is False
        """
        fields = self._FIELDS
        field_order = self._FIELD_ORDER
        assert len(args) <= len(field_order)
        if args:  # pragma: no cover
            values = dict(zip(field_order, args))
            values.update(kwargs)
        else:
            values = kwargs
        for k, v in values.items():
            # only the lookup is guarded, so that a KeyError raised while
            # assigning the value is never mistaken for an unknown field
            try:
                field = fields[k]
            except KeyError:
                if not self._ignore_non_existing_keys:
                    raise
                continue
            field.set(self, v)

    def __on_deserialize__(self):
        """Hook invoked by the serializer once all fields are restored."""
        pass

    def __repr__(self):
        # only names listed in ``__slots__`` as seen from the instance are
        # shown; slots without a value are rendered as None
        values = ", ".join(
            f"{slot}={getattr(self, slot, None)!r}" for slot in self.__slots__
        )
        return f"{self.__class__.__name__}({values})"

    def copy_to(self, target: "Serializable") -> "Serializable":
        """
        Copy the values of all fields that are set into ``target``.

        When ``target`` has no field of the same name, the field named with
        an extra leading underscore is used instead. Fields without a value
        are skipped.

        Returns
        -------
        ``target`` itself
        """
        copied_fields = target._FIELDS
        for k, field in self._FIELDS.items():
            try:
                # Slightly faster than getattr.
                value = field.get(self, k)
                try:
                    copied_fields[k].set(target, value)
                except KeyError:
                    copied_fields["_" + k].set(target, value)
            except AttributeError:
                # the field holds no value in the source object
                continue
        return target

    def copy(self) -> "Serializable":
        """Return a new instance of the same class with field values copied."""
        return self.copy_to(type(self)())


# msgpack-encoded primitive field values per object, filled by
# SerializableSerializer.serial for classes with ``_cache_primitive_serial``
# enabled; the objects used as keys are referenced weakly
_primitive_serial_cache = weakref.WeakKeyDictionary()


class _NoFieldValue:
    """Type of the sentinel standing for a field that holds no value."""


# the instance used across this module, always compared by identity
_no_field_value = _NoFieldValue()


def _to_primitive_placeholder(v: Any) -> Any:
    """
    Substitute an empty dict for the ``_no_field_value`` and ``no_default``
    sentinels so that the value can be packed with msgpack; any other
    value is returned untouched.
    """
    if v is not _no_field_value and v is not no_default:
        return v
    return {}


def _restore_primitive_placeholder(v: Any) -> Any:
    """
    Map an empty plain ``dict`` back to ``_no_field_value`` and return
    every other value as it is.

    Only objects whose type is exactly ``dict`` are inspected, instances of
    dict subclasses pass through even when they are empty.
    """
    is_placeholder = type(v) is dict and not v
    return _no_field_value if is_placeholder else v


class SerializableSerializer(Serializer):
    """
    Serializer of ``Serializable`` objects.

    Values of primitive fields travel in the serialized header, values of
    the remaining fields are returned as sub-objects. For every class the
    number of fields contributed by each class of its hierarchy is stored
    as public data of the context, which lets the receiving side align the
    values with its own field definitions even if they differ from the
    ones of the sender.
    """

    @classmethod
    def _log_legacy(cls, context: Dict, key: Any, msg: str, *args, **kwargs):
        """
        Log ``msg`` unless a message with the same ``key`` was already
        logged for ``context``.

        The keys already seen are kept in ``context`` under
        ``_deprecate_log_key``. The log level can be given with the
        ``level`` keyword argument and defaults to ``logging.WARNING``;
        remaining arguments are forwarded to ``logger.log``.
        """
        level = kwargs.pop("level", logging.WARNING)
        try:
            logged_keys = context[_deprecate_log_key]
        except KeyError:
            logged_keys = context[_deprecate_log_key] = set()
        if key not in logged_keys:
            logged_keys.add(key)
            logger.log(level, msg, *args, **kwargs)

    @classmethod
    def _get_obj_field_count_key(cls, obj: Serializable, legacy: bool = False):
        """
        Return the public data key of the field counts of ``obj``, built
        from the legacy name hash of its class when ``legacy`` is True and
        from the current name hash otherwise.
        """
        name_hash = obj._LEGACY_NAME_HASH if legacy else obj._NAME_HASH
        return f"FC_{name_hash}"

    @classmethod
    def _get_field_values(cls, obj: Serializable, fields):
        """
        Collect the values of ``fields`` from ``obj`` in the given order.

        ``on_serialize`` of a field is applied when defined. A field whose
        value cannot be obtained because of an ``AttributeError`` is
        represented by ``_no_field_value``.
        """
        values = []
        for field in fields:
            try:
                value = field.get(obj)
                if field.on_serialize is not None:
                    value = field.on_serialize(value)
            except AttributeError:
                # Most fields do have values, thus serializing values as a
                # list with placeholders is more efficient than as a dict.
                value = _no_field_value
            values.append(value)
        return values

    @buffered
    def serial(self, obj: Serializable, context: Dict):
        """
        Serialize ``obj``.

        Returns
        -------
        a tuple of three elements: the header made of
        ``"<module>#<qualname>"`` of the class and the primitive values
        (a list, or msgpack-encoded bytes when primitive values of the class
        are cached), a one-element list holding the list of non-primitive
        values, and ``False``
        """
        if obj._cache_primitive_serial and obj in _primitive_serial_cache:
            primitive_vals = _primitive_serial_cache[obj]
        else:
            primitive_vals = self._get_field_values(obj, obj._PRIMITIVE_FIELDS)
            # replace _no_field_value as {} to make them msgpack-serializable
            primitive_vals = [_to_primitive_placeholder(v) for v in primitive_vals]
            if obj._cache_primitive_serial:
                primitive_vals = msgpack.dumps(primitive_vals)
                _primitive_serial_cache[obj] = primitive_vals

        compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
        cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"

        field_count_key = self._get_obj_field_count_key(obj)
        if not self.is_public_data_exist(context, field_count_key):
            # store field distribution for current Serializable
            counts = [
                list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
                list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
            ]
            field_count_data = msgpack.dumps(counts)
            self.put_public_data(context, field_count_key, field_count_data)
        return [cls_module, primitive_vals], [compound_vals], False

    @staticmethod
    def _set_field_value(obj: Serializable, field: Field, value):
        """
        Assign ``value`` to ``field`` of ``obj``.

        Nothing is assigned for ``_no_field_value``. For a ``Placeholder``
        the assignment is deferred by appending a callback to it. In both
        the direct and the deferred case ``on_deserialize`` of the field is
        applied to the value when defined.
        """
        if value is _no_field_value:
            return
        if type(value) is Placeholder:
            if field.on_deserialize is not None:
                value.callbacks.append(
                    lambda v: field.set(obj, field.on_deserialize(v))
                )
            else:
                value.callbacks.append(lambda v: field.set(obj, v))
        else:
            if field.on_deserialize is not None:
                field.set(obj, field.on_deserialize(value))
            else:
                field.set(obj, value)

    @classmethod
    def _prune_server_fields(
        cls,
        client_cls_to_field_count: Optional[Dict[int, int]],
        server_cls_to_field_count: Dict[int, int],
        server_fields: list,
        legacy_to_new_hash: Dict[int, int],
    ) -> list:
        """
        Drop field definitions of classes that the client did not report.

        Parameters
        ----------
        client_cls_to_field_count
            class hash -> number of fields as received from the client
        server_cls_to_field_count
            class hash -> number of fields of the local class definition
        server_fields
            local field definitions, grouped by class in the order of
            ``server_cls_to_field_count``
        legacy_to_new_hash
            legacy name hash -> current name hash

        Returns
        -------
        ``server_fields`` itself when both sides report the same classes,
        otherwise a new list restricted to the classes known to the client
        under their current or legacy hash
        """
        if set(client_cls_to_field_count.keys()) == set(
            server_cls_to_field_count.keys()
        ):
            return server_fields

        new_to_legacy_hash = {v: k for k, v in legacy_to_new_hash.items()}
        ret_server_fields = []
        server_pos = 0
        for cls_hash, count in server_cls_to_field_count.items():
            if (
                cls_hash in client_cls_to_field_count
                or new_to_legacy_hash.get(cls_hash) in client_cls_to_field_count
            ):
                ret_server_fields.extend(server_fields[server_pos : server_pos + count])
            server_pos += count
        return ret_server_fields

    @classmethod
    def _set_field_values(
        cls,
        obj: Serializable,
        values: List[Any],
        client_cls_to_field_count: Optional[Dict[int, int]],
        is_primitive: bool = True,
    ):
        """
        Assign received ``values`` to the fields of ``obj``.

        Parameters
        ----------
        obj
            object being restored
        values
            field values in the order used by the client
        client_cls_to_field_count
            class hash -> number of values the client sent for the class
        is_primitive
            whether ``values`` belong to primitive or non-primitive fields
        """
        obj_class = type(obj)
        legacy_to_new_hash = obj_class._LEGACY_TO_NEW_NAME_HASH

        if is_primitive:
            server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
            field_def_list = obj_class._PRIMITIVE_FIELDS
        else:
            server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
            field_def_list = obj_class._NON_PRIMITIVE_FIELDS

        server_fields = cls._prune_server_fields(
            client_cls_to_field_count,
            server_cls_to_field_count,
            field_def_list,
            legacy_to_new_hash,
        )

        field_num, server_field_num = 0, 0
        for cls_hash, count in client_cls_to_field_count.items():
            try:
                server_count = server_cls_to_field_count[cls_hash]
            except KeyError:
                # the client may identify the class by its legacy hash. It
                # is also possible that certain type of field does not exist
                # at server side, the class then owns no field here
                server_count = server_cls_to_field_count.get(
                    legacy_to_new_hash.get(cls_hash), 0
                )

            # cut values and fields given field distribution at client and
            # server end. Fields are sliced with the server-side offset and
            # count only: using the client offset as the upper bound leaves
            # the slice empty once the two offsets diverge, and slicing past
            # the server count assigns values to fields of the next class
            cls_fields = server_fields[
                server_field_num : server_field_num + server_count
            ]
            cls_values = values[field_num : field_num + count]
            # zip stops at the shorter side, thus surplus values of the
            # client as well as surplus fields of the server are left alone
            for field, value in zip(cls_fields, cls_values):
                if is_primitive:
                    value = _restore_primitive_placeholder(value)
                if not is_primitive or value is not _no_field_value:
                    cls._set_field_value(obj, field, value)
            field_num += count
            server_field_num += server_count

    def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
        """
        Rebuild a ``Serializable`` from the output of ``serial``.

        The object is created without calling ``__init__``, the field values
        are assigned according to the field counts published by the sender
        and ``__on_deserialize__`` is invoked at the end.

        Raises
        ------
        MaxFrameDeprecationError
            when no field count information is found for the class, neither
            under its current nor under its legacy key
        """
        obj_class_name, primitives = serialized
        obj_class = load_type(obj_class_name, Serializable)

        # cached primitive values arrive msgpack-encoded instead of as a list
        if type(primitives) is not list:
            primitives = msgpack.loads(primitives)

        obj = obj_class.__new__(obj_class)

        field_count_data = self.get_public_data(
            context, self._get_obj_field_count_key(obj)
        )
        if field_count_data is None:
            # try using legacy field count key to get counts
            field_count_data = self.get_public_data(
                context, self._get_obj_field_count_key(obj, legacy=True)
            )

            if field_count_data is None:
                self._log_legacy(
                    context,
                    ("MISSING_CLASS", obj_class_name),
                    "Field count info of %s not found in serialized data",
                    obj_class_name,
                    level=logging.ERROR,
                )
                raise MaxFrameDeprecationError(
                    "Failed to deserialize request. Please upgrade your "
                    "MaxFrame client to the latest release."
                )
            else:
                self._log_legacy(
                    context,
                    ("LEGACY_CLASS", obj_class_name),
                    "Class %s used in legacy client",
                    obj_class_name,
                )

        # counts are packed as lists of (class hash, count) pairs
        cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
        cls_to_prim_key = dict(cls_to_prim_key)
        cls_to_non_prim_key = dict(cls_to_non_prim_key)

        if primitives:
            self._set_field_values(obj, primitives, cls_to_prim_key, True)
        if obj_class._NON_PRIMITIVE_FIELDS:
            self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
        obj.__on_deserialize__()
        return obj


class NoFieldValueSerializer(Serializer):
    """
    Serializer of the ``_no_field_value`` sentinel.

    The sentinel carries no state: nothing is written when it is serialized
    and the module-level instance is handed back when it is deserialized.
    """

    def serial(self, obj, context):
        # empty header and no sub-objects; the trailing flag is True here
        # whereas SerializableSerializer.serial returns False
        header, sub_objects = [], []
        return header, sub_objects, True

    def deserial(self, serialized, context, subs):
        # always resolve to the shared instance to keep identity checks
        # against ``_no_field_value`` valid
        return _no_field_value


# associate the serializers defined above with the types they handle
SerializableSerializer.register(Serializable)
NoFieldValueSerializer.register(_NoFieldValue)
