# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Tuple

try:
    import numpy as np
except ImportError:
    np = None

logger = logging.getLogger(__name__)


NULL_FLAG = -3
# This flag indicates that object is a not-null value.
# We don't use another byte to indicate REF, so that we can save one byte.
REF_FLAG = -2
# this flag indicates that the object is a non-null value.
NOT_NULL_VALUE_FLAG = -1
# this flag indicates that the object is a referencable and first read.
REF_VALUE_FLAG = 0


class RefResolver(ABC):
    @abstractmethod
    def write_ref_or_null(self, buffer, obj):
        """
        Write reference and tag for the obj if the obj has been written
        previously, write null/not-null tag otherwise.

        Returns
        -------
            true if no bytes need to be written for the object.
        """

    @abstractmethod
    def read_ref_or_null(self, buffer):
        """
        Returns
        -------
            `REF_FLAG` if a reference to a previously read object was
            read.
            `NULL_FLAG` if the object is null.
            `REF_VALUE_FLAG` if the object is not null and reference tracking is
             not enabled or the object is first read.
        """

    @abstractmethod
    def preserve_ref_id(self) -> int:
        """
        Preserve a reference id, which is used by `setReadObject` to set up
        reference for object that is first deserialized.

        Returns
        -------
            a reference id or -1 if reference is not enabled.
        """

    @abstractmethod
    def try_preserve_ref_id(self, buffer) -> int:
        """
        Preserve and return a `refId` which is `>=` {@link NOT_NULL_VALUE_FLAG}
        if the value is not null. If the value is referencable value, the `refId`
        will be {@link #preserveReferenceId}.

        Returns
        -------
            a reference id
        """

    @abstractmethod
    def reference(self, obj):
        """
        Call this method immediately after composited object such as object
        array/map/collection/bean is created, so that circular reference can
        be deserialized correctly.
        """

    @abstractmethod
    def get_read_object(self, id_=None):
        """
        Returns
        -------
            the object for the specified id.
        """

    @abstractmethod
    def set_read_object(self, id_, obj):
        """
        Sets the id for an object that has been read.

        Parameters
        ----------
        id_: int
            The id from {@link #nextReadRefId)}.
        obj:
            the object that has been read
        """

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def reset_write(self):
        pass

    @abstractmethod
    def reset_read(self):
        pass


class MapRefResolver(RefResolver):
    written_objects: Dict[int, Tuple[int, Any]]  # id(obj) -> (ref_id, obj)
    read_objects: List[Any]
    read_ref_ids: List[int]

    def __init__(self):
        self.written_objects = dict()
        self.read_objects = list()
        self.read_ref_ids = list()
        self.read_object = None

    def write_ref_or_null(self, buffer, obj):
        if obj is None:
            buffer.write_int8(NULL_FLAG)
            return True
        else:
            object_id = id(obj)
            written_id = self.written_objects.get(object_id, None)
            # The obj has been written previously.
            if written_id is not None:
                buffer.write_int8(REF_FLAG)
                buffer.write_varuint32(written_id[0])
                return True
            else:
                written_id = len(self.written_objects)
                # Hold object to avoid tmp object gc when serialize nested
                # fields/objects.
                self.written_objects[object_id] = (written_id, obj)
            buffer.write_int8(REF_VALUE_FLAG)
            return False

    def read_ref_or_null(self, buffer):
        head_flag = buffer.read_int8()
        if head_flag == REF_FLAG:
            # read reference id and get object from reference resolver
            ref_id = buffer.read_varuint32()
            self.read_object = self.get_read_object(ref_id)
            return REF_FLAG
        else:
            self.read_object = None
            return head_flag

    def preserve_ref_id(self) -> int:
        next_read_ref_id = len(self.read_objects)
        self.read_objects.append(None)
        self.read_ref_ids.append(next_read_ref_id)
        return next_read_ref_id

    def try_preserve_ref_id(self, buffer) -> int:
        head_flag = buffer.read_int8()
        if head_flag == REF_FLAG:
            # read reference id and get object from reference resolver
            ref_id = buffer.read_varuint32()
            self.read_object = self.get_read_object(id_=ref_id)
        else:
            self.read_object = None
            if head_flag == REF_VALUE_FLAG:
                return self.preserve_ref_id()
        # `head_flag` except `REF_FLAG` can be used as stub reference id because we use
        # `refId >= NOT_NULL_VALUE_FLAG` to read data.
        return head_flag

    def reference(self, obj):
        ref_id = self.read_ref_ids.pop()
        self.set_read_object(ref_id, obj)

    def get_read_object(self, id_=None):
        if id_ is None:
            return self.read_object
        return self.read_objects[id_]

    def set_read_object(self, id_, obj):
        if id_ >= 0:
            if id_ >= len(self.read_objects):
                raise RuntimeError(f"Ref id {id_} invalid")
            self.read_objects[id_] = obj

    def reset(self):
        self.reset_write()
        self.reset_read()

    def reset_write(self):
        self.written_objects.clear()

    def reset_read(self):
        self.read_objects.clear()
        self.read_ref_ids.clear()
        self.read_object = None


class NoRefResolver(RefResolver):
    def write_ref_or_null(self, buffer, obj):
        if obj is None:
            buffer.write_int8(NULL_FLAG)
            return True
        else:
            buffer.write_int8(NOT_NULL_VALUE_FLAG)
            return False

    def read_ref_or_null(self, buffer):
        return buffer.read_int8()

    def preserve_ref_id(self) -> int:
        return -1

    def try_preserve_ref_id(self, buffer) -> int:
        # `NOT_NULL_VALUE_FLAG` can be used as stub reference id because we use
        # `refId >= NOT_NULL_VALUE_FLAG` to read data.
        return buffer.read_int8()

    def reference(self, obj):
        pass

    def get_read_object(self, id_=None):
        return None

    def set_read_object(self, id_, obj):
        pass

    def reset(self):
        pass

    def reset_write(self):
        pass

    def reset_read(self):
        pass
