# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import datetime
import typing

import pyarrow as pa

from functools import partial
from typing import Optional
from pyfury.type import get_qualified_classname, TypeVisitor, infer_field

__class_map__ = {}
__schemas__ = {}  # ensure `id(schema)` doesn't get duplicate.


def get_cls_by_schema(schema):
    id_ = id(schema)
    if id_ not in __class_map__:
        meta = {} if schema.metadata is None else schema.metadata
        cls_name = meta.get(b"cls", b"").decode()
        if cls_name:
            import importlib

            module_name, class_name = cls_name.rsplit(".", 1)
            mod = importlib.import_module(module_name)
            cls_ = getattr(mod, class_name)
        else:
            from pyfury.type import record_class_factory

            cls_ = record_class_factory(
                "Record" + str(id(schema)), [f.name for f in schema]
            )
        __class_map__[id_] = cls_
        __schemas__[id_] = schema
    return __class_map__[id_]


def remove_schema(schema):
    __schemas__.pop(id(schema))


def reset():
    __class_map__.clear()
    __schemas__.clear()


_supported_types = {
    pa.bool_,
    pa.int8,
    pa.int16,
    pa.int32,
    pa.int64,
    pa.float32,
    pa.float64,
    str,
    bytes,
    typing.List,
    typing.Dict,
}
_supported_types_str = [
    f"{t.__module__}.{getattr(t, '__name__', t)}" for t in _supported_types
]
_supported_types_mapping = {t: t for t in _supported_types}
_supported_types_mapping.update(
    {
        str: pa.utf8,
        bytes: pa.binary,
        list: pa.list_,
        dict: pa.map_,
        typing.List: pa.list_,
        typing.Dict: pa.map_,
        bool: pa.bool_,
        datetime.date: pa.date32,
        datetime.datetime: partial(pa.timestamp, "us"),
    }
)


def infer_schema(clz, types_path=None) -> pa.Schema:
    types_path = list(types_path or [])
    type_hints = typing.get_type_hints(clz)
    keys = sorted(type_hints.keys())
    fields = [
        infer_field(
            field_name,
            type_hints[field_name],
            ArrowTypeVisitor(),
            types_path=types_path,
        )
        for field_name in keys
    ]
    return pa.schema(fields, metadata={"cls": get_qualified_classname(clz)})


class ArrowTypeVisitor(TypeVisitor):
    def visit_list(self, field_name, elem_type, types_path=None):
        # Infer type recursively for type such as List[Dict[str, str]]
        elem_field = infer_field("item", elem_type, self, types_path=types_path)
        return pa.field(field_name, pa.list_(elem_field.type))

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        # Infer type recursively for type such as Dict[str, Dict[str, str]]
        key_field = infer_field("key", key_type, self, types_path=types_path)
        value_field = infer_field("value", value_type, self, types_path=types_path)
        return pa.field(field_name, pa.map_(key_field.type, value_field.type))

    def visit_customized(self, field_name, type_, types_path=None):
        # type_ is a pojo
        pojo_schema = infer_schema(type_)
        fields = list(pojo_schema)
        return pa.field(
            field_name,
            pa.struct(fields),
            metadata={"cls": get_qualified_classname(type_)},
        )

    def visit_other(self, field_name, type_, types_path=None):
        # use _supported_types_mapping instead of _supported_types, because
        # typing.List/typing.Dict's origin will be list/dict
        if type_ not in _supported_types_mapping:
            raise TypeError(
                f"Type {type_} not supported, currently only "
                f"compositions of {_supported_types_str} are supported. "
                f"types_path is {types_path}"
            )
        arrow_type_func = _supported_types_mapping.get(type_)
        return pa.field(field_name, arrow_type_func())


def infer_data_type(clz) -> Optional[pa.DataType]:
    try:
        return infer_field("", clz, ArrowTypeVisitor()).type
    except TypeError:
        return None


def get_type_id(clz) -> Optional[int]:
    type_ = infer_data_type(clz)
    if type_:
        return type_.id
    else:
        return None


def compute_schema_hash(schema: pa.Schema):
    hash_ = 17
    for f in schema:
        hash_ = _compute_hash(hash_, f.type)
    return hash_


def _compute_hash(hash_: int, type_: pa.DataType):
    while True:
        h = hash_ * 31 + type_.id
        if h > 2**63 - 1:
            hash_ = hash_ >> 2
        else:
            hash_ = h
            break
    types = []
    if isinstance(type_, pa.ListType):
        types.append(type_.value_type)
    elif isinstance(type_, pa.MapType):
        types.append(type_.key_type)
        types.append(type_.item_type)
    elif isinstance(type_, pa.StructType):
        types.extend([f.type for f in type_])
    else:
        assert (
            type_.num_fields == 0
        ), f"field type should not be nested, but got type {type_}."

    for t in types:
        hash_ = _compute_hash(hash_, t)
    return hash_
