# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import array
import dataclasses
import datetime
import enum
import functools
import logging
from typing import TypeVar, Union
from enum import Enum

from pyfury._serialization import ENABLE_FURY_CYTHON_SERIALIZATION
from pyfury import Language
from pyfury.error import TypeUnregisteredError

from pyfury.serializer import (
    Serializer,
    Numpy1DArraySerializer,
    NDArraySerializer,
    PyArraySerializer,
    DynamicPyArraySerializer,
    _PickleStub,
    PickleStrongCacheStub,
    PickleCacheStub,
    NoneSerializer,
    BooleanSerializer,
    ByteSerializer,
    Int16Serializer,
    Int32Serializer,
    Int64Serializer,
    Float32Serializer,
    Float64Serializer,
    StringSerializer,
    DateSerializer,
    TimestampSerializer,
    BytesSerializer,
    ListSerializer,
    TupleSerializer,
    MapSerializer,
    SetSerializer,
    EnumSerializer,
    SliceSerializer,
    PickleCacheSerializer,
    PickleStrongCacheSerializer,
    PickleSerializer,
    DataClassSerializer,
)
from pyfury._struct import ComplexObjectSerializer
from pyfury.meta.metastring import MetaStringEncoder, MetaStringDecoder
from pyfury.type import (
    TypeId,
    Int8Type,
    Int16Type,
    Int32Type,
    Int64Type,
    Float32Type,
    Float64Type,
    load_class,
)
from pyfury._fury import (
    DYNAMIC_TYPE_ID,
    # preserve 0 as flag for class id not set in `ClassInfo`
    NO_CLASS_ID,
)

try:
    import numpy as np
except ImportError:
    np = None

logger = logging.getLogger(__name__)


if ENABLE_FURY_CYTHON_SERIALIZATION:
    from pyfury._serialization import ClassInfo
else:

    class ClassInfo:
        """Pure-Python record describing one registered class.

        Defined only when the Cython serialization module is disabled;
        otherwise ``ClassInfo`` is imported from ``pyfury._serialization``.
        """

        __slots__ = (
            "cls",
            "type_id",
            "serializer",
            "namespace_bytes",
            "typename_bytes",
            "dynamic_type",
        )

        def __init__(
            self,
            cls: type = None,
            type_id: int = NO_CLASS_ID,
            serializer: Serializer = None,
            namespace_bytes=None,
            typename_bytes=None,
            dynamic_type: bool = False,
        ):
            # Identity of the registered class.
            self.cls, self.type_id = cls, type_id
            # Encoded namespace / type name; both stay None for id-only types.
            self.namespace_bytes, self.typename_bytes = namespace_bytes, typename_bytes
            self.serializer = serializer
            self.dynamic_type = dynamic_type

        def __repr__(self):
            cls, type_id, serializer = self.cls, self.type_id, self.serializer
            return f"ClassInfo(cls={cls}, type_id={type_id}, serializer={serializer})"


class ClassResolver:
    """Registry that maps Python classes to ``ClassInfo`` entries and back.

    Entries can be looked up by class object (``_classes_info``), by numeric
    type id (``_type_id_to_classinfo``) or by namespace/type name
    (``_named_type_to_classinfo`` and ``_ns_type_to_classinfo``).
    ``write_typeinfo`` and ``read_typeinfo`` use these tables to emit and
    resolve type information on a buffer.
    """

    # NOTE(review): `_dynamic_id_to_classinfo_list` and
    # `_dynamic_id_to_metastr_list` are declared as slots but are not assigned
    # in `__init__`; reading them before assignment raises AttributeError.
    __slots__ = (
        "fury",
        "_metastr_to_str",
        "_type_id_counter",
        "_classes_info",
        "_hash_to_metastring",
        "_metastr_to_class",
        "_hash_to_classinfo",
        "_dynamic_id_to_classinfo_list",
        "_dynamic_id_to_metastr_list",
        "_dynamic_write_string_id",
        "_dynamic_written_metastr",
        "_ns_type_to_classinfo",
        "_named_type_to_classinfo",
        "namespace_encoder",
        "namespace_decoder",
        "typename_encoder",
        "typename_decoder",
        "require_registration",
        "metastring_resolver",
        "language",
        "_type_id_to_classinfo",
    )

    def __init__(self, fury):
        """Create an empty resolver bound to ``fury``.

        Only empty lookup tables and the meta string codecs are set up here;
        built-in types are registered by ``initialize``.
        """
        self.fury = fury
        self.metastring_resolver = fury.metastring_resolver
        self.language = fury.language
        self.require_registration = fury.require_class_registration

        # Codecs for namespaces ("." / "_") and type names ("$" / "_").
        self.namespace_encoder = MetaStringEncoder(".", "_")
        self.namespace_decoder = MetaStringDecoder(".", "_")
        self.typename_encoder = MetaStringEncoder("$", "_")
        self.typename_decoder = MetaStringDecoder("$", "_")

        # `_next_type_id` increments before returning, so the first
        # automatically assigned id is 65.
        self._type_id_counter = 64
        self._dynamic_write_string_id = 0
        self._dynamic_written_metastr = []

        # Meta string caches.
        self._metastr_to_str = {}
        self._metastr_to_class = {}
        self._hash_to_metastring = {}
        self._hash_to_classinfo = {}

        # Lookup tables for registered types.
        self._type_id_to_classinfo = {}
        # hold objects to avoid gc, since `flat_hash_map/vector` doesn't
        # hold python reference.
        self._classes_info = {}
        self._ns_type_to_classinfo = {}
        self._named_type_to_classinfo = {}

    def initialize(self):
        """Register the built-in types.

        Cross-language built-ins are always registered; the Python-only
        built-ins are added only when the Fury language is ``Language.PYTHON``.
        """
        self._initialize_xlang()
        if self.fury.language != Language.PYTHON:
            return
        self._initialize_py()

    def _initialize_py(self):
        """Register the built-in types that exist only in Python mode.

        The pickle stubs get fixed type ids; ``NoneType``, ``tuple`` and
        ``slice`` are registered without an id and therefore receive
        automatically assigned ones, in this order.
        """
        add_internal = functools.partial(self._register_type, internal=True)

        add_internal(
            _PickleStub,
            type_id=PickleSerializer.PICKLE_CLASS_ID,
            serializer=PickleSerializer,
        )

        strong_cache_serializer = PickleStrongCacheSerializer(self.fury)
        add_internal(
            PickleStrongCacheStub, type_id=97, serializer=strong_cache_serializer
        )

        cache_serializer = PickleCacheSerializer(self.fury)
        add_internal(PickleCacheStub, type_id=98, serializer=cache_serializer)

        auto_id_types = (
            (type(None), NoneSerializer),
            (tuple, TupleSerializer),
            (slice, SliceSerializer),
        )
        for py_type, serializer_cls in auto_id_types:
            add_internal(py_type, serializer=serializer_cls)

    def _initialize_xlang(self):
        """Register the built-in cross-language types as internal types.

        Registration order is significant: several Python types share one
        type id, and for internal registrations the first class registered
        under an id keeps the id-to-classinfo slot.
        """
        add_internal = functools.partial(self._register_type, internal=True)

        scalar_types = (
            (None, TypeId.NA, NoneSerializer),
            (bool, TypeId.BOOL, BooleanSerializer),
            (Int8Type, TypeId.INT8, ByteSerializer),
            (Int16Type, TypeId.INT16, Int16Serializer),
            (Int32Type, TypeId.INT32, Int32Serializer),
            (Int64Type, TypeId.INT64, Int64Serializer),
            (int, TypeId.INT64, Int64Serializer),
            (Float32Type, TypeId.FLOAT32, Float32Serializer),
            (Float64Type, TypeId.FLOAT64, Float64Serializer),
            (float, TypeId.FLOAT64, Float64Serializer),
            (str, TypeId.STRING, StringSerializer),
            # TODO(chaokunyang) DURATION DECIMAL
            (datetime.datetime, TypeId.TIMESTAMP, TimestampSerializer),
            (datetime.date, TypeId.LOCAL_DATE, DateSerializer),
            (bytes, TypeId.BINARY, BytesSerializer),
        )
        for py_type, xtype_id, serializer_cls in scalar_types:
            add_internal(py_type, type_id=xtype_id, serializer=serializer_cls)

        # Typed `array.array` variants, one per entry of the typecode table.
        for _, array_type, array_type_id in PyArraySerializer.typecode_dict.values():
            array_serializer = PyArraySerializer(self.fury, array_type, array_type_id)
            add_internal(
                array_type, type_id=array_type_id, serializer=array_serializer
            )
        add_internal(
            array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer
        )

        if np:
            # overwrite pyarray  with same type id.
            # if pyarray are needed, one must annotate that value with XXXArrayType
            # as a field of a struct.
            numpy_dtypes = Numpy1DArraySerializer.dtypes_dict
            for np_dtype, (_, _, array_type, array_type_id) in numpy_dtypes.items():
                numpy_serializer = Numpy1DArraySerializer(
                    self.fury, array_type, np_dtype
                )
                add_internal(
                    array_type, type_id=array_type_id, serializer=numpy_serializer
                )
            add_internal(
                np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer
            )

        container_types = (
            (list, TypeId.LIST, ListSerializer),
            (set, TypeId.SET, SetSerializer),
            (dict, TypeId.MAP, MapSerializer),
        )
        for py_type, xtype_id, serializer_cls in container_types:
            add_internal(py_type, type_id=xtype_id, serializer=serializer_cls)

        # Arrow support is best-effort: any failure while importing or
        # registering the Arrow types is ignored.
        try:
            import pyarrow as pa
            from pyfury.format.serializer import (
                ArrowRecordBatchSerializer,
                ArrowTableSerializer,
            )

            add_internal(
                pa.RecordBatch,
                type_id=TypeId.ARROW_RECORD_BATCH,
                serializer=ArrowRecordBatchSerializer,
            )
            add_internal(
                pa.Table, type_id=TypeId.ARROW_TABLE, serializer=ArrowTableSerializer
            )
        except Exception:
            pass

    def register_type(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
    ):
        """Register ``cls`` as a non-internal type and return its ``ClassInfo``.

        All arguments are forwarded unchanged to ``_register_type``; the
        ``internal`` flag is left at its default of ``False``.
        """
        options = dict(
            type_id=type_id,
            namespace=namespace,
            typename=typename,
            serializer=serializer,
        )
        return self._register_type(cls, **options)

    def _register_type(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
        internal=False,
    ):
        """Validate a registration request and dispatch it by language.

        Register class with given type id or typename. If typename is not
        None, it will be used for cross-language serialization.

        ``serializer`` may be a ``Serializer`` instance or a callable that
        builds one; a callable is tried as ``serializer(fury, cls)``, then
        ``serializer(fury)``, then ``serializer()``. When neither ``type_id``
        nor ``typename`` is given, the next free type id is assigned.

        Raises
        ------
        TypeError
            If both ``typename`` and ``type_id`` are given, or if ``cls`` is
            already registered.
        """
        if serializer is not None and not isinstance(serializer, Serializer):
            serializer_factory = serializer
            # Probe the supported constructor signatures. Catch `Exception`
            # rather than `BaseException` so that KeyboardInterrupt and
            # SystemExit raised during construction still propagate.
            try:
                serializer = serializer_factory(self.fury, cls)
            except Exception:
                try:
                    serializer = serializer_factory(self.fury)
                except Exception:
                    serializer = serializer_factory()
        if typename is not None and type_id is not None:
            raise TypeError(
                f"type name {typename} and id {type_id} should not be set at the same time"
            )
        if typename is None and type_id is None:
            type_id = self._next_type_id()
        if type_id not in {0, None}:
            # multiple classes can share the same type id, so an id-based
            # registration is rejected only when the class is known as well.
            if type_id in self._type_id_to_classinfo and cls in self._classes_info:
                raise TypeError(f"{cls} registered already")
        elif cls in self._classes_info:
            raise TypeError(f"{cls} registered already")
        register_type = (
            self._register_xtype
            if self.fury.language == Language.XLANG
            else self._register_pytype
        )
        return register_type(
            cls,
            type_id=type_id,
            namespace=namespace,
            typename=typename,
            serializer=serializer,
            internal=internal,
        )

    def _register_xtype(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
        internal=False,
    ):
        """Register ``cls`` for cross-language mode.

        Without a serializer, enums get an ``EnumSerializer`` and every other
        class a ``ComplexObjectSerializer``. A user-supplied id is shifted
        left by 8 bits and combined with the kind (ENUM / STRUCT / EXT); with
        no id, the matching NAMED_* kind is used. Internal registrations that
        bring their own serializer keep ``type_id`` untouched.
        """
        if serializer is None:
            if issubclass(cls, enum.Enum):
                serializer = EnumSerializer(self.fury, cls)
                named_kind, id_kind = TypeId.NAMED_ENUM, TypeId.ENUM
            else:
                serializer = ComplexObjectSerializer(self.fury, cls)
                named_kind, id_kind = TypeId.NAMED_STRUCT, TypeId.STRUCT
            if type_id is None:
                type_id = named_kind
            else:
                type_id = (type_id << 8) + id_kind
        elif not internal:
            if type_id is None:
                type_id = TypeId.NAMED_EXT
            else:
                type_id = (type_id << 8) + TypeId.EXT
        return self.__register_type(
            cls,
            type_id=type_id,
            serializer=serializer,
            namespace=namespace,
            typename=typename,
            internal=internal,
        )

    def _register_pytype(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer: Serializer = None,
        internal: bool = False,
    ):
        """Register ``cls`` for Python-only mode and return its ``ClassInfo``.

        A registration that supplies only ``typename`` arrives here with
        ``type_id=None``. The shared registrar compares ``type_id`` with 0,
        which raises ``TypeError`` for ``None``, so such registrations are
        given the next free type id before being forwarded.
        """
        if type_id is None:
            type_id = self._next_type_id()
        return self.__register_type(
            cls,
            type_id=type_id,
            namespace=namespace,
            typename=typename,
            serializer=serializer,
            internal=internal,
        )

    def __register_type(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer: Serializer = None,
        internal: bool = False,
    ):
        """Build the ``ClassInfo`` for ``cls`` and store it in the lookup tables.

        ``type_id`` must be an integer here: it is compared with 0, so passing
        ``None`` raises ``TypeError``. When ``typename`` is given without a
        ``namespace``, a dotted name is split at its last ``.`` into
        namespace and type name.
        """
        is_dynamic = type_id < 0
        if serializer is None and not internal:
            serializer = self._create_serializer(cls)

        named = typename is not None
        ns_meta_bytes = type_meta_bytes = None
        if named:
            if namespace is None:
                parts = typename.rsplit(".", 1)
                if len(parts) == 2:
                    namespace, typename = parts
            resolver = self.metastring_resolver
            ns_meta_bytes = resolver.get_metastr_bytes(
                self.namespace_encoder.encode(namespace or "")
            )
            type_meta_bytes = resolver.get_metastr_bytes(
                self.typename_encoder.encode(typename)
            )

        classinfo = ClassInfo(
            cls, type_id, serializer, ns_meta_bytes, type_meta_bytes, is_dynamic
        )
        if named:
            self._named_type_to_classinfo[(namespace, typename)] = classinfo
            self._ns_type_to_classinfo[(ns_meta_bytes, type_meta_bytes)] = classinfo
        self._classes_info[cls] = classinfo

        if type_id > 0:
            by_id = self._type_id_to_classinfo
            # Outside Python mode, namespaced types are not indexed by id.
            indexed_by_id = (
                self.language == Language.PYTHON
                or not TypeId.is_namespaced_type(type_id)
            )
            # An internal registration never replaces an existing id entry.
            if indexed_by_id and not (internal and type_id in by_id):
                by_id[type_id] = classinfo
        return classinfo

    def _next_type_id(self):
        type_id = self._type_id_counter = self._type_id_counter + 1
        while type_id in self._type_id_to_classinfo:
            type_id = self._type_id_counter = self._type_id_counter + 1
        return type_id

    def register_serializer(self, cls: Union[type, TypeVar], serializer):
        """Replace the serializer of an already registered class.

        In Python mode only the serializer is swapped. In other modes a
        different serializer also turns the type into an extension type: the
        kind bits of its type id become ``NAMED_EXT`` (named types) or
        ``EXT``, and the id lookup table is re-keyed accordingly.

        Raises
        ------
        TypeUnregisteredError
            If ``cls`` has not been registered.
        """
        assert isinstance(cls, (type, TypeVar)), cls
        classinfo = self._classes_info.get(cls)
        if classinfo is None:
            raise TypeUnregisteredError(f"{cls} not registered")
        if self.fury.language == Language.PYTHON:
            classinfo.serializer = serializer
            return
        prev_type_id = classinfo.type_id
        type_id = prev_type_id
        if classinfo.serializer is not serializer:
            ext_kind = (
                TypeId.NAMED_EXT
                if classinfo.typename_bytes is not None
                else TypeId.EXT
            )
            type_id = (prev_type_id & 0xFFFFFF00) | ext_kind
        # Namespaced types are never stored by id outside Python mode, and
        # several classes may share one id, so only drop the previous entry
        # when it really belongs to this class.
        if self._type_id_to_classinfo.get(prev_type_id) is classinfo:
            del self._type_id_to_classinfo[prev_type_id]
        # Keep the ClassInfo in sync with the table: without this the new
        # serializer would never be used and the old type id would still be
        # written.
        classinfo.serializer = serializer
        classinfo.type_id = type_id
        self._type_id_to_classinfo[type_id] = classinfo

    def get_serializer(self, cls: type):
        """
        Returns
        -------
            Returns or create serializer for the provided class
        """
        return self.get_classinfo(cls).serializer

    def get_classinfo(self, cls, create=True):
        class_info = self._classes_info.get(cls)
        if class_info is not None:
            if class_info.serializer is None:
                class_info.serializer = self._create_serializer(cls)
            return class_info
        elif not create:
            return None
        if self.language != Language.PYTHON or (
            self.require_registration and not issubclass(cls, Enum)
        ):
            raise TypeUnregisteredError(f"{cls} not registered")
        logger.info("Class %s not registered", cls)
        serializer = self._create_serializer(cls)
        type_id = None
        if self.language == Language.PYTHON:
            if isinstance(serializer, EnumSerializer):
                type_id = TypeId.NAMED_ENUM
            elif type(serializer) is PickleSerializer:
                type_id = PickleSerializer.PICKLE_CLASS_ID
            if not self.require_registration:
                if isinstance(serializer, DataClassSerializer):
                    type_id = TypeId.NAMED_STRUCT
        if type_id is None:
            raise TypeUnregisteredError(
                f"{cls} must be registered using `fury.register_type` API"
            )
        return self.__register_type(
            cls,
            type_id=type_id,
            namespace=cls.__module__,
            typename=cls.__qualname__,
            serializer=serializer,
        )

    def _create_serializer(self, cls):
        """Build a serializer instance for ``cls``.

        The first registered class in ``cls.__mro__`` whose serializer
        reports ``support_subclass()`` decides the serializer type. Otherwise
        dataclasses get a ``DataClassSerializer``, enums an
        ``EnumSerializer`` and everything else a ``PickleSerializer``.
        """
        registered = self._classes_info
        for base in cls.__mro__:
            info = registered.get(base)
            if not info:
                continue
            base_serializer = info.serializer
            if base_serializer and base_serializer.support_subclass():
                return type(base_serializer)(self.fury, cls)

        if dataclasses.is_dataclass(cls):
            from pyfury import DataClassSerializer

            return DataClassSerializer(self.fury, cls)
        if issubclass(cls, enum.Enum):
            return EnumSerializer(self.fury, cls)
        return PickleSerializer(self.fury, cls)

    def _load_metabytes_to_classinfo(self, ns_metabytes, type_metabytes):
        typeinfo = self._ns_type_to_classinfo.get((ns_metabytes, type_metabytes))
        if typeinfo is not None:
            return typeinfo
        ns = ns_metabytes.decode(self.namespace_decoder)
        typename = type_metabytes.decode(self.typename_decoder)
        # the hash computed between languages may be different.
        typeinfo = self._named_type_to_classinfo.get((ns, typename))
        if typeinfo is not None:
            self._ns_type_to_classinfo[(ns_metabytes, type_metabytes)] = typeinfo
            return typeinfo
        cls = load_class(ns + "#" + typename)
        classinfo = self.get_classinfo(cls)
        self._ns_type_to_classinfo[(ns_metabytes, type_metabytes)] = classinfo
        return classinfo

    def write_typeinfo(self, buffer, classinfo):
        if classinfo.dynamic_type:
            return
        type_id = classinfo.type_id
        internal_type_id = type_id & 0xFF
        buffer.write_varuint32(type_id)
        if TypeId.is_namespaced_type(internal_type_id):
            self.metastring_resolver.write_meta_string_bytes(
                buffer, classinfo.namespace_bytes
            )
            self.metastring_resolver.write_meta_string_bytes(
                buffer, classinfo.typename_bytes
            )

    def read_typeinfo(self, buffer):
        """Read type information from ``buffer`` and return its ``ClassInfo``.

        Ids whose low byte is not a namespaced kind are looked up directly in
        ``_type_id_to_classinfo``. Namespaced ids are followed by namespace
        and type-name meta strings, which are resolved through the meta-bytes
        cache and then the decoded-name table.

        Raises
        ------
        TypeUnregisteredError
            If a namespaced type is not registered under the decoded name.
        """
        type_id = buffer.read_varuint32()
        if not TypeId.is_namespaced_type(type_id & 0xFF):
            return self._type_id_to_classinfo[type_id]

        read_meta = self.metastring_resolver.read_meta_string_bytes
        ns_metabytes = read_meta(buffer)
        type_metabytes = read_meta(buffer)
        cache_key = (ns_metabytes, type_metabytes)
        typeinfo = self._ns_type_to_classinfo.get(cache_key)
        if typeinfo is not None:
            return typeinfo

        ns = ns_metabytes.decode(self.namespace_decoder)
        typename = type_metabytes.decode(self.typename_decoder)
        typeinfo = self._named_type_to_classinfo.get((ns, typename))
        if typeinfo is None:
            # TODO(chaokunyang) generate a dynamic class and serializer
            #  when meta share is enabled.
            name = ns + "." + typename if ns else typename
            raise TypeUnregisteredError(f"{name} not registered")
        self._ns_type_to_classinfo[cache_key] = typeinfo
        return typeinfo

    def reset(self):
        pass

    def reset_read(self):
        pass

    def reset_write(self):
        pass
