#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import threading
import zlib

from ... import compat, errors, options
from ...compat import BytesIO, Enum, Semaphore, six
from ...lib.monotonic import monotonic
from ..errors import TunnelError

try:
    from urllib3.exceptions import ReadTimeoutError
except ImportError:
    from requests import ReadTimeout as ReadTimeoutError

MICRO_SEC_PER_SEC = 1000000

# used for test case to force thread io
_FORCE_THREAD = False

# memoryview slices must be converted before being handed to APIs that
# expect ``bytes``.  On Python versions before 3.2 ``bytes()`` on a
# memoryview does not yield its raw contents, hence the ``bytearray``
# round-trip there.
if compat.LESS_PY32:
    mv_to_bytes = lambda v: bytes(bytearray(v))
else:
    mv_to_bytes = bytes


if compat.six.PY3:

    def cast_memoryview(v):
        # normalize any buffer to a flat unsigned-byte memoryview so that
        # byte-wise slicing and assignment behave uniformly
        if not isinstance(v, memoryview):
            v = memoryview(v)
        return v.cast("B")

else:

    def cast_memoryview(v):
        # Python 2's memoryview has no ``cast``; return the plain view
        if not isinstance(v, memoryview):
            v = memoryview(v)
        return v


class RequestsIO(object):
    """Pipes data written by the caller into a streaming HTTP POST request.

    ``write()`` accumulates bytes and forwards full chunks via ``put()`` to a
    producer (thread or greenlet) that executes ``post_call`` with
    ``data_generator()`` as the request body.  ``finish()`` flushes pending
    data, signals end-of-stream with ``put(None)`` and returns the response.

    ``__new__`` selects the concrete implementation: greenlet-based IO when
    available and running on the main thread, thread-based IO otherwise.
    """

    # default size of chunks handed to the HTTP layer
    CHUNK_SIZE = 256 * 1024

    def __new__(cls, *args, **kwargs):
        if cls is RequestsIO:
            # _FORCE_THREAD lets tests force the thread variant; greenlet IO
            # is only selected when running on the main thread
            if (
                not isinstance(threading.current_thread(), threading._MainThread)
                or _FORCE_THREAD
            ):
                return object.__new__(ThreadRequestsIO)
            elif GreenletRequestsIO is not None:
                return object.__new__(GreenletRequestsIO)
            else:
                return object.__new__(ThreadRequestsIO)
        else:
            return object.__new__(cls)

    def __init__(self, post_call, chunk_size=None, record_io_time=False):
        self._buf = BytesIO()  # pending data not yet handed to the producer
        self._resp = None  # response object returned by post_call
        self._async_err = None  # sys.exc_info() captured in the producer
        self._chunk_size = chunk_size or self.CHUNK_SIZE

        self._post_call = post_call
        self._wait_obj = None  # thread joined in finish(), if any
        self._record_io_time = record_io_time
        # NOTE(review): despite the ``_ms`` suffix, increments below multiply
        # elapsed seconds by MICRO_SEC_PER_SEC (1e6), i.e. microseconds --
        # confirm the intended unit
        self._io_time_ms = 0
        self._io_start_time = 0
        self._io_end_time = 0

    def _async_func(self):
        # runs inside the producer thread/greenlet; any exception is stashed
        # and re-raised in the caller via _reraise_errors()
        try:
            if self._record_io_time:
                self._io_start_time = monotonic()
            self._resp = self._post_call(self.data_generator())
        except:
            self._async_err = sys.exc_info()
        self._wait_obj = None

    def _reraise_errors(self):
        # surface an exception captured in the producer to the caller thread
        if self._async_err is not None:
            ex_type, ex_value, tb = self._async_err
            six.reraise(ex_type, ex_value, tb)

    @property
    def io_time_ms(self):
        # accumulated time spent in network IO (see unit note in __init__)
        return self._io_time_ms

    def data_generator(self):
        """Generator consumed by the HTTP library as the request body.

        Pulls chunks via ``get()`` (None terminates the stream) and re-slices
        each into pieces of at most ``_chunk_size`` bytes.
        """
        if self._record_io_time:
            self._io_time_ms += int(
                MICRO_SEC_PER_SEC * (monotonic() - self._io_start_time)
            )
        chunk_size = self._chunk_size
        while True:
            data = self.get()
            if data is not None:
                data = memoryview(data)
                while data:
                    to_send = mv_to_bytes(data[:chunk_size])
                    data = data[chunk_size:]

                    if self._record_io_time:
                        ts = monotonic()

                    yield to_send

                    # time between yield and resume approximates the HTTP
                    # layer's send time for this piece
                    if self._record_io_time:
                        self._io_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
            else:
                break
        if self._record_io_time:
            self._io_end_time = monotonic()

    def start(self):
        # overridden by subclasses that launch a background producer
        pass

    def get(self):
        # blocking retrieval of the next chunk; None means end of stream
        raise NotImplementedError

    def put(self, data):
        # hand one chunk to the producer; None signals end of stream
        raise NotImplementedError

    def write(self, data):
        """Buffer ``data``; forward to the producer once a full chunk is ready."""
        self._buf.write(data)
        if self._buf.tell() >= self._chunk_size:
            chunk = self._buf.getvalue()
            self._buf = BytesIO()
            self.put(chunk)
            self._reraise_errors()

    def flush(self):
        """Forward any partially filled buffer to the producer."""
        if self._buf.tell():
            chunk = self._buf.getvalue()
            self._buf = BytesIO()
            self.put(chunk)
            self._reraise_errors()

    def finish(self):
        """Flush remaining data, terminate the stream and return the response."""
        self.flush()
        self.put(None)
        wait_obj = self._wait_obj
        if wait_obj and wait_obj.is_alive():
            wait_obj.join()

        if self._record_io_time:
            self._io_time_ms += int(
                MICRO_SEC_PER_SEC * (monotonic() - self._io_end_time)
            )

        self._reraise_errors()
        return self._resp


class ThreadRequestsIO(RequestsIO):
    """Thread-backed RequestsIO.

    A daemon thread runs the HTTP POST while the caller feeds it chunks
    through a two-semaphore handshake: ``_sem_put`` guards the single
    producer slot (initially free), ``_sem_get`` signals that ``_last_data``
    holds a chunk ready for the consumer.
    """

    def __init__(self, post_call, chunk_size=None, record_io_time=False):
        super(ThreadRequestsIO, self).__init__(
            post_call, chunk_size, record_io_time=record_io_time
        )
        self._last_data = None  # single-slot mailbox between the two threads
        self._sem_put = Semaphore(1)  # free slots available to put()
        self._sem_get = Semaphore(0)  # filled slots available to get()
        self._wait_obj = threading.Thread(target=self._async_func)
        self._wait_obj.daemon = True
        self._acquire_timeout = options.connect_timeout

    def _async_func(self):
        try:
            super(ThreadRequestsIO, self)._async_func()
        finally:
            # make sure subsequent put() call does not get stuck
            self._sem_put.release()

    def start(self):
        # launch the posting thread
        self._wait_obj.start()

    def get(self):
        # called from the posting thread: wait until a chunk is available,
        # take it, then free the slot for the next put()
        self._sem_get.acquire()
        data = self._last_data
        self._sem_put.release()
        return data

    def put(self, data):
        """Hand one chunk (or None for end-of-stream) to the posting thread."""
        self._reraise_errors()
        assert self._wait_obj is not None and self._wait_obj.is_alive()
        try:
            # NOTE(review): Semaphore.acquire(timeout=...) and the built-in
            # TimeoutError are Python 3 only -- confirm Python 2 support is
            # not required here
            rc = self._sem_put.acquire(timeout=self._acquire_timeout)
            if not rc:
                raise TimeoutError("Wait for data semaphore timed out")
            self._reraise_errors()
            self._last_data = data
        except:
            # leave None in the mailbox so the consumer observes
            # end-of-stream instead of re-sending stale data
            self._last_data = None
            raise
        finally:
            self._sem_get.release()


try:
    from greenlet import greenlet

    class GreenletRequestsIO(RequestsIO):
        """Greenlet-backed RequestsIO.

        Producer and consumer cooperate on a single thread by explicitly
        switching between two greenlets; ``_last_data`` is the handover slot.
        """

        def __init__(self, post_call, chunk_size=None, record_io_time=False):
            super(GreenletRequestsIO, self).__init__(
                post_call, chunk_size, record_io_time=record_io_time
            )
            self._cur_greenlet = greenlet.getcurrent()
            self._writer_greenlet = greenlet(self._async_func)
            self._last_data = None
            # run the writer greenlet until it blocks in get() awaiting data
            self._writer_greenlet.switch()

        def get(self):
            # switch back to the caller greenlet until put() resumes us
            self._cur_greenlet.switch()
            return self._last_data

        def put(self, data):
            self._last_data = data
            # handover control
            self._writer_greenlet.switch()

except ImportError:
    # greenlet not installed: RequestsIO.__new__ falls back to threads
    GreenletRequestsIO = None


class CompressOption(object):
    """Holds the compression algorithm, level and strategy for tunnel IO."""

    class CompressAlgorithm(Enum):
        ODPS_RAW = "RAW"
        ODPS_ZLIB = "ZLIB"
        ODPS_SNAPPY = "SNAPPY"
        ODPS_ZSTD = "ZSTD"
        ODPS_LZ4 = "LZ4"
        ODPS_ARROW_LZ4 = "ARROW_LZ4"

        def get_encoding(self, legacy=True):
            """Return the wire encoding name of this algorithm.

            ``legacy`` selects HTTP content-encoding style names; otherwise
            the newer protocol identifiers are used.  Raw compression maps to
            ``None``.  Raises :class:`TunnelError` for unsupported options.
            """
            cls = type(self)
            if legacy:
                encodings = {
                    cls.ODPS_RAW: None,
                    cls.ODPS_ZLIB: "deflate",
                    cls.ODPS_ZSTD: "zstd",
                    cls.ODPS_LZ4: "x-lz4-frame",
                    cls.ODPS_SNAPPY: "x-snappy-framed",
                    cls.ODPS_ARROW_LZ4: "x-odps-lz4-frame",
                }
            else:
                encodings = {
                    cls.ODPS_RAW: None,
                    cls.ODPS_ZSTD: "ZSTD",
                    cls.ODPS_LZ4: "LZ4_FRAME",
                    cls.ODPS_ARROW_LZ4: "LZ4_FRAME",
                }
            try:
                return encodings[self]
            except KeyError:
                raise TunnelError("invalid compression option")

        @classmethod
        def from_encoding(cls, encoding):
            """Return the algorithm for a wire encoding name (case-insensitive).

            ``None`` and ``"identity"`` both mean no compression.  Raises
            :class:`TunnelError` for unknown encodings.
            """
            encoding = encoding.lower() if encoding else None
            algorithms = {
                None: cls.ODPS_RAW,
                "identity": cls.ODPS_RAW,
                "deflate": cls.ODPS_ZLIB,
                "zstd": cls.ODPS_ZSTD,
                "x-lz4-frame": cls.ODPS_LZ4,
                "x-snappy-framed": cls.ODPS_SNAPPY,
                "x-odps-lz4-frame": cls.ODPS_ARROW_LZ4,
                "lz4_frame": cls.ODPS_ARROW_LZ4,
            }
            try:
                return algorithms[encoding]
            except KeyError:
                raise TunnelError("invalid encoding name %s" % encoding)

    def __init__(
        self, compress_algo=CompressAlgorithm.ODPS_ZLIB, level=None, strategy=None
    ):
        compress_algo = compress_algo or self.CompressAlgorithm.ODPS_ZLIB
        if not isinstance(compress_algo, CompressOption.CompressAlgorithm):
            # accept plain strings such as "zlib" as well as enum members
            compress_algo = CompressOption.CompressAlgorithm(compress_algo.upper())
        self.algorithm = compress_algo
        self.level = level or 1
        self.strategy = strategy or 0


# both LZ4 variants share the same compress/decompress stream classes
_lz4_algorithms = (
    CompressOption.CompressAlgorithm.ODPS_LZ4,
    CompressOption.CompressAlgorithm.ODPS_ARROW_LZ4,
)


def get_compress_stream(buffer, compress_option=None):
    """Wrap ``buffer`` with a compressing writer selected by ``compress_option``.

    Raw (or absent) compression returns ``buffer`` unchanged.  Raises
    ``errors.InvalidArgument`` for unknown algorithms.
    """
    algo = getattr(compress_option, "algorithm", None)
    if algo is None or algo == CompressOption.CompressAlgorithm.ODPS_RAW:
        return buffer

    algo_to_stream = {
        CompressOption.CompressAlgorithm.ODPS_ZLIB: DeflateOutputStream,
        CompressOption.CompressAlgorithm.ODPS_ZSTD: ZstdOutputStream,
        CompressOption.CompressAlgorithm.ODPS_SNAPPY: SnappyOutputStream,
        CompressOption.CompressAlgorithm.ODPS_LZ4: LZ4OutputStream,
        CompressOption.CompressAlgorithm.ODPS_ARROW_LZ4: LZ4OutputStream,
    }
    try:
        stream_cls = algo_to_stream[algo]
    except KeyError:
        raise errors.InvalidArgument("Invalid compression algorithm %s." % algo)
    return stream_cls(buffer, level=compress_option.level)


def get_decompress_stream(resp, compress_option=None, requests=True):
    """Return a file-like reader that decompresses ``resp``.

    When ``requests`` is True, a requests-response-aware stream class is
    used; otherwise the matching plain input stream class is selected via
    ``get_raw_input_stream_class``.  Raises ``errors.InvalidArgument`` for
    unknown algorithms.
    """
    algo = getattr(compress_option, "algorithm", None)
    if algo is None or algo == CompressOption.CompressAlgorithm.ODPS_RAW:
        # no compression: just wrap the body as a file-like object
        stream_cls = RequestsInputStream
    else:
        algo_to_stream = {
            CompressOption.CompressAlgorithm.ODPS_ZLIB: DeflateRequestsInputStream,
            CompressOption.CompressAlgorithm.ODPS_ZSTD: ZstdRequestsInputStream,
            CompressOption.CompressAlgorithm.ODPS_SNAPPY: SnappyRequestsInputStream,
            CompressOption.CompressAlgorithm.ODPS_LZ4: LZ4RequestsInputStream,
            CompressOption.CompressAlgorithm.ODPS_ARROW_LZ4: LZ4RequestsInputStream,
        }
        try:
            stream_cls = algo_to_stream[algo]
        except KeyError:
            raise errors.InvalidArgument("Invalid compression algorithm %s." % algo)

    if not requests:
        stream_cls = stream_cls.get_raw_input_stream_class()
    return stream_cls(resp)


class CompressOutputStream(object):
    """Base class for writers that compress data into an underlying stream.

    Subclasses supply the compressor object via ``_get_compressor``; a falsy
    compressor makes ``write`` pass data through unmodified.
    """

    def __init__(self, output, level=1):
        self._compressor = self._get_compressor(level=level)
        self._output = output

    def _get_compressor(self, level=1):
        """Return the compressor object; implemented by subclasses."""
        raise NotImplementedError

    def write(self, data):
        """Compress ``data`` and write any produced bytes to the output."""
        if not self._compressor:
            self._output.write(data)
            return
        compressed = self._compressor.compress(data)
        # the compressor may buffer internally and emit nothing this call
        if compressed:
            self._output.write(compressed)

    def flush(self):
        """Drain the compressor's internal buffer into the output stream."""
        if self._compressor:
            tail = self._compressor.flush()
            if tail:
                self._output.write(tail)


class DeflateOutputStream(CompressOutputStream):
    """Compressing writer producing a zlib (deflate) stream."""

    def _get_compressor(self, level=1):
        return zlib.compressobj(level)


class SnappyOutputStream(CompressOutputStream):
    """Compressing writer producing a snappy framed stream.

    Requires the optional ``python-snappy`` package; ``level`` is accepted
    for interface compatibility but snappy has no compression levels.
    """

    def _get_compressor(self, level=1):
        try:
            import snappy
        except ImportError:
            raise errors.DependencyNotInstalledError(
                "python-snappy library is required for snappy support"
            )
        return snappy.StreamCompressor()


class ZstdOutputStream(CompressOutputStream):
    """Compressing writer producing a zstd stream.

    Requires the optional ``zstandard`` package.
    """

    def _get_compressor(self, level=1):
        try:
            import zstandard
        except ImportError:
            raise errors.DependencyNotInstalledError(
                "zstandard library is required for zstd support"
            )
        # NOTE(review): ``level`` is not forwarded to ZstdCompressor, so the
        # library default level is always used -- confirm this is intended
        return zstandard.ZstdCompressor().compressobj()


class LZ4OutputStream(CompressOutputStream):
    """Compressing writer producing an LZ4 frame stream.

    Requires the optional ``lz4`` package.  The frame header is emitted
    lazily by the first ``write()`` call.
    """

    def _get_compressor(self, level=1):
        try:
            import lz4.frame
        except ImportError:
            raise errors.DependencyNotInstalledError(
                "lz4 library is required for lz4 support"
            )
        # tracks whether the LZ4 frame header has been written yet
        self._header_written = False
        return lz4.frame.LZ4FrameCompressor(compression_level=level)

    def write(self, data):
        """Write ``data``, emitting the frame header before the first chunk."""
        if not self._header_written:
            self._output.write(self._compressor.begin())
            self._header_written = True
        super(LZ4OutputStream, self).write(data)


class SimpleInputStream(object):
    """Buffered file-like reader over an underlying ``read()``-able input.

    Data is pulled from the input in blocks of ``READ_BLOCK_SIZE`` and served
    out of an internal ``memoryview`` buffer; supports ``read``, ``readinto``
    and single-byte ``peek``.  Subclasses may transform blocks by overriding
    ``_buffer_next_chunk`` (e.g. for decompression).
    """

    READ_BLOCK_SIZE = 1024 * 64

    def __init__(self, input):
        self._input = input
        self._internal_buffer = memoryview(b"")  # view over the current chunk
        self._buffered_len = 0  # number of valid bytes in the buffer
        self._buffered_pos = 0  # current read offset inside the buffer
        self._pos = 0  # total bytes handed out so far
        self._closed = False

    @staticmethod
    def readable():
        return True

    def __len__(self):
        # NOTE(review): returns the number of bytes consumed so far, not the
        # total or remaining stream length -- confirm callers expect this
        return self._pos

    def read(self, limit):
        """Read up to ``limit`` bytes; returns ``b""`` at end of stream."""
        if self._closed:
            raise IOError("closed")

        # fast path: the request is served entirely from the current buffer
        if limit <= self._buffered_len - self._buffered_pos:
            mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + limit]
            self._buffered_pos += len(mv)
            self._pos += len(mv)
            return mv_to_bytes(mv)

        bufs = list()
        size_left = limit
        while size_left > 0:
            content = self._internal_read(size_left)
            if not content:
                break
            bufs.append(content)
            size_left -= len(content)
        ret = bytes().join(bufs)
        self._pos += len(ret)
        return ret

    def peek(self):
        """Return the next byte without consuming it, or None at EOF."""
        if self._buffered_pos == self._buffered_len:
            self._refill_buffer()

        if self._buffered_pos == self._buffered_len:
            # still nothing can be read
            return None
        return self._internal_buffer[self._buffered_pos]

    def readinto(self, b):
        """Fill buffer ``b`` with stream data; return the number of bytes read."""
        if self._closed:
            raise IOError("closed")

        b = cast_memoryview(b)
        limit = len(b)
        # fast path: copy directly from the current buffer
        if limit <= self._buffered_len - self._buffered_pos:
            mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + limit]
            self._buffered_pos += len(mv)
            b[:limit] = mv
            self._pos += len(mv)
            return len(mv)

        pos = 0
        while pos < limit:
            rsize = self._internal_readinto(b, pos)
            if not rsize:
                break
            pos += rsize
        self._pos += pos
        return pos

    def _internal_read(self, limit):
        # read at most ``limit`` bytes from the buffer, refilling it first
        # when fully consumed
        if self._buffered_pos == self._buffered_len:
            self._refill_buffer()
        mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + limit]
        self._buffered_pos += len(mv)
        return mv_to_bytes(mv)

    def _internal_readinto(self, b, start):
        # copy buffered bytes into ``b[start:]``; returns the copied size
        if self._buffered_pos == self._buffered_len:
            self._refill_buffer()
        size = len(b) - start
        mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + size]
        size = len(mv)
        self._buffered_pos += size
        b[start : start + size] = mv
        return size

    def _refill_buffer(self):
        # pull chunks until one yields data (skipping empty ones) or the
        # source signals EOF with None; the loop breaks after the first
        # non-empty chunk, so ``buffer`` holds at most one element
        self._buffered_pos = 0
        self._buffered_len = 0

        buffer = []
        while True:
            content = self._buffer_next_chunk()
            if content is None:
                break
            if content:
                length = len(content)
                self._buffered_len += length
                buffer.append(content)
                break

        if len(buffer) == 1:
            self._internal_buffer = memoryview(buffer[0])
        else:
            # empty list at EOF: reset to an empty view
            self._internal_buffer = memoryview(bytes().join(buffer))

    def _read_block(self):
        # read one raw block from the source; empty reads normalize to None
        content = self._input.read(self.READ_BLOCK_SIZE)
        return content if content else None

    def _buffer_next_chunk(self):
        # hook for subclasses to transform raw blocks (e.g. decompression)
        return self._read_block()

    @property
    def closed(self):
        return self._closed

    def close(self):
        # marks this stream closed; does not close the underlying input
        self._closed = True


class DecompressInputStream(SimpleInputStream):
    """Input stream that decompresses each raw block before buffering it."""

    def __init__(self, input):
        super(DecompressInputStream, self).__init__(input)
        self._decompressor = self._get_decompressor()

    def _get_decompressor(self):
        """Return the decompressor object; implemented by subclasses."""
        raise NotImplementedError

    def _buffer_next_chunk(self):
        data = self._read_block()
        if data is None:
            return None
        if data:
            return self._decompressor.decompress(data)
        else:
            # NOTE(review): _read_block() normalizes empty reads to None, so
            # this flush branch appears unreachable with the readers in this
            # file; kept for overriding implementations that return b""
            return self._decompressor.flush()


class DeflateInputStream(DecompressInputStream):
    """Reader decompressing a zlib (deflate) stream."""

    def _get_decompressor(self):
        # positive MAX_WBITS selects the zlib container format
        return zlib.decompressobj(zlib.MAX_WBITS)


class SnappyInputStream(DecompressInputStream):
    """Reader decompressing a snappy framed stream.

    Requires the optional ``python-snappy`` package.
    """

    def _get_decompressor(self):
        try:
            import snappy
        except ImportError:
            raise errors.DependencyNotInstalledError(
                "python-snappy library is required for snappy support"
            )
        return snappy.StreamDecompressor()


class ZstdInputStream(DecompressInputStream):
    """Reader decompressing a zstd stream.

    Requires the optional ``zstandard`` package.
    """

    def _get_decompressor(self):
        try:
            import zstandard
        except ImportError:
            raise errors.DependencyNotInstalledError(
                "zstandard library is required for zstd support"
            )
        return zstandard.ZstdDecompressor().decompressobj()


class LZ4InputStream(DecompressInputStream):
    """Reader decompressing an LZ4 frame stream.

    Requires the optional ``lz4`` package.
    """

    def _get_decompressor(self):
        try:
            import lz4.frame
        except ImportError:
            raise errors.DependencyNotInstalledError(
                "lz4 library is required for lz4 support"
            )
        return lz4.frame.LZ4FrameDecompressor()


class RawRequestsStreamMixin(object):
    """Mixin reading blocks from a requests response's raw urllib3 stream.

    Subclasses combine this mixin with a :class:`SimpleInputStream` subclass;
    ``get_raw_input_stream_class`` recovers that plain (non-requests) stream
    class so the same decompression logic can be reused on ordinary
    file-like inputs.
    """

    # whether urllib3 should honor Content-Encoding while reading
    _decode_content = False

    @classmethod
    def get_raw_input_stream_class(cls):
        """Return the first non-mixin MRO base that is a plain input stream.

        Returns None when no such base exists.
        """
        for base in cls.__mro__:
            if (
                base is not cls
                and base is not RawRequestsStreamMixin
                # fix: check the candidate ``base`` itself, not ``cls`` --
                # the original issubclass(cls, ...) test was loop-invariant
                # and could return a non-stream base for exotic hierarchies
                and issubclass(base, SimpleInputStream)
            ):
                return base
        return None

    def _read_block(self):
        """Read one block from the raw stream; None signals EOF.

        On a read timeout, invokes ``options.tunnel_read_timeout_callback``
        (when configured) with the exception info before re-raising.
        """
        try:
            content = self._input.raw.read(
                self.READ_BLOCK_SIZE, decode_content=self._decode_content
            )
            return content if content else None
        except ReadTimeoutError:
            if callable(options.tunnel_read_timeout_callback):
                options.tunnel_read_timeout_callback(*sys.exc_info())
            raise


class RequestsInputStream(RawRequestsStreamMixin, SimpleInputStream):
    """Uncompressed reader over a requests response; urllib3 is allowed to
    undo any transport-level Content-Encoding (``_decode_content=True``)."""

    _decode_content = True


# Requests/urllib3 automatically decompress deflate/gzip-encoded data!
class DeflateRequestsInputStream(RawRequestsStreamMixin, SimpleInputStream):
    """Deflate reader over a requests response.

    Derives from the plain ``SimpleInputStream`` (not ``DeflateInputStream``)
    because urllib3 already decodes the body when ``decode_content`` is on.
    """

    _decode_content = True

    @classmethod
    def get_raw_input_stream_class(cls):
        # non-requests inputs get no transparent decoding, so a real deflate
        # decompressing stream is required there
        return DeflateInputStream


class SnappyRequestsInputStream(RawRequestsStreamMixin, SnappyInputStream):
    """Snappy reader over a requests response's raw stream."""

    pass


class ZstdRequestsInputStream(RawRequestsStreamMixin, ZstdInputStream):
    """Zstd reader over a requests response's raw stream."""

    pass


class LZ4RequestsInputStream(RawRequestsStreamMixin, LZ4InputStream):
    """LZ4 frame reader over a requests response's raw stream."""

    pass
