# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-unsafe

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import http.server as http_server
import zlib
from io import BytesIO as StringIO
from struct import pack, unpack

from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TCompactProtocol import getVarint, readVarint, TCompactProtocol
from thrift.Thrift import TApplicationException
from thrift.transport.TTransport import (
    TTransportException,
    TTransportBase,
    CReadableTransport,
)

# Keep the name `xrange` available in this module as an alias of the
# built-in `range`.
xrange = range

# Use the real snappy module when it can be imported.
try:
    import snappy
except ImportError:
    # A missing snappy module is not fatal at import time: install a
    # stand-in whose methods raise only when snappy compression is
    # actually requested.
    class DummySnappy(object):
        """Placeholder for the snappy module; every operation fails."""

        def _unavailable(self):
            # Single place that reports the missing dependency.
            raise TTransportException(
                TTransportException.INVALID_TRANSFORM, "snappy module not available"
            )

        def compress(self, buf):
            """Always raises TTransportException (INVALID_TRANSFORM)."""
            self._unavailable()

        def decompress(self, buf):
            """Always raises TTransportException (INVALID_TRANSFORM)."""
            self._unavailable()

    snappy = DummySnappy()

# Use the real zstd module when it can be imported.
try:
    import zstd  # @manual
except ImportError:
    # A missing zstd module is not fatal at import time: install a
    # stand-in whose factories raise only when zstd compression is
    # actually requested.
    class DummyZstd(object):
        """Placeholder for the zstd module; every factory fails."""

        def _unavailable(self):
            # Single place that reports the missing dependency.
            raise TTransportException(
                TTransportException.INVALID_TRANSFORM, "zstd module not available"
            )

        def ZstdCompressor(self, write_content_size):
            """Always raises TTransportException (INVALID_TRANSFORM)."""
            self._unavailable()

        def ZstdDecompressor(self):
            """Always raises TTransportException (INVALID_TRANSFORM)."""
            self._unavailable()

    zstd = DummyZstd()


# Definitions from THeader.h


class CLIENT_TYPE:
    """Identifiers for the wire format used by the peer.

    THeaderTransport.readFrame() stores one of these after inspecting the
    first bytes of an incoming message, and THeaderTransport.flushImpl()
    picks the outgoing format from the stored value.
    """

    HEADER = 0  # frame whose first two bytes are the header magic
    FRAMED_DEPRECATED = 1  # length-prefixed frame, binary protocol
    UNFRAMED_DEPRECATED = 2  # no length prefix, binary protocol
    HTTP_SERVER = 3  # message began with "POST"
    HTTP_CLIENT = 4
    FRAMED_COMPACT = 5  # length-prefixed frame, compact protocol
    HTTP_GET = 7
    UNKNOWN = 8  # format not detected (yet)
    UNFRAMED_COMPACT_DEPRECATED = 9  # no length prefix, compact protocol


class HEADER_FLAG:
    """Bit values for the flags field of a header-format frame.

    They are combined with THeaderTransport.set_header_flag() and removed
    with THeaderTransport.clear_header_flag().
    """

    SUPPORT_OUT_OF_ORDER = 1 << 0
    DUPLEX_REVERSE = 1 << 3


class TRANSFORM:
    """Transform ids that can be listed in a header-format frame.

    THeaderTransport applies ZLIB, SNAPPY and ZSTD; an incoming HMAC id is
    rejected, and any other id is reported as an unknown transform.
    """

    NONE = 0
    ZLIB = 1
    HMAC = 2
    SNAPPY = 3
    QLZ = 4
    ZSTD = 5


class INFO:
    """Type tags for the key/value info sections of a header-format frame."""

    # Headers that THeaderTransport clears before parsing each new frame.
    NORMAL = 1
    # Headers collected in a separate dict that is not cleared per frame.
    PERSISTENT = 2


# Protocol ids as carried in the header of a header-format frame.
T_BINARY_PROTOCOL = 0
T_COMPACT_PROTOCOL = 2

# Magic marking a header-format frame; its upper 16 bits are what actually
# appears on the wire (PACKED_HEADER_MAGIC is that value in network order).
HEADER_MAGIC = 0x0FFF0000
PACKED_HEADER_MAGIC: bytes = pack("!H", HEADER_MAGIC >> 16)
HEADER_MASK = 0xFFFF0000
FLAGS_MASK = 0x0000FFFF

# First four bytes of a message, read as a big-endian 32-bit integer.
HTTP_SERVER_MAGIC = 0x504F5354  # "POST"
HTTP_CLIENT_MAGIC = 0x48545450  # "HTTP"
HTTP_GET_CLIENT_MAGIC = 0x47455420  # "GET "
HTTP_HEAD_CLIENT_MAGIC = 0x48454144  # "HEAD"
BIG_FRAME_MAGIC = 0x42494746  # "BIGF"

# Largest size expressible with the 4-byte length prefix, and with the
# 8-byte length that follows BIG_FRAME_MAGIC.
MAX_FRAME_SIZE = 0x3FFFFFFF
MAX_BIG_FRAME_SIZE = (1 << 61) - 1


class THeaderTransport(TTransportBase, CReadableTransport):
    """Transport that sends headers.  Also understands framed/unframed/HTTP
    transports and will do the right thing.

    Outgoing bytes are buffered by write() and emitted by flush() /
    onewayFlush() in the wire format selected by the current client type.
    Incoming data is loaded by readFrame(), which detects the peer's format
    from the first bytes of each message.
    """

    __max_frame_size = MAX_FRAME_SIZE

    # No identity header is sent unless one is supplied via set_identity().
    __identity = None
    # True until the first header-format flush, which attaches the
    # client-metadata header once (see disable_client_metadata()).
    __first_request = True
    IDENTITY_HEADER = "identity"
    ID_VERSION_HEADER = "id_version"
    ID_VERSION = "1"
    CLIENT_METADATA_HEADER = "client_metadata"

    def __init__(self, trans, client_types=None, client_type=None):
        """Wrap the underlying transport ``trans``.

        @param trans: transport that actually moves the bytes
        @param client_types: iterable of CLIENT_TYPE values accepted when
            reading; defaults to (CLIENT_TYPE.HEADER,)
        @param client_type: CLIENT_TYPE used for writing; framed and
            unframed types are replaced by CLIENT_TYPE.HEADER
        """
        self.__trans = trans
        self.__rbuf = StringIO()
        self.__rbuf_frame = False
        self.__wbuf = StringIO()
        self.seq_id = 0
        self.__flags = 0
        self.__read_transforms = []
        self.__write_transforms = []
        self.__supported_client_types = set(client_types or (CLIENT_TYPE.HEADER,))
        self.__proto_id = T_COMPACT_PROTOCOL  # default to compact like c++
        self.__client_type = client_type or CLIENT_TYPE.HEADER
        self.__read_headers = {}
        self.__read_persistent_headers = {}
        self.__write_headers = {}
        self.__write_persistent_headers = {}

        if self.__client_type in (
            CLIENT_TYPE.UNFRAMED_DEPRECATED,
            CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED,
            CLIENT_TYPE.FRAMED_DEPRECATED,
            CLIENT_TYPE.FRAMED_COMPACT,
        ):
            self.__client_type = CLIENT_TYPE.HEADER

        self.__supported_client_types.add(self.__client_type)

        # If we support unframed binary / framed binary also support compact
        if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types:
            self.__supported_client_types.add(CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED)
        if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types:
            self.__supported_client_types.add(CLIENT_TYPE.FRAMED_COMPACT)

    def set_header_flag(self, flag):
        """Set the bit(s) in ``flag`` in the outgoing flags field."""
        self.__flags |= flag

    def clear_header_flag(self, flag):
        """Clear the bit(s) in ``flag`` from the outgoing flags field."""
        self.__flags &= ~flag

    def header_flags(self):
        """Return the current flags value."""
        return self.__flags

    def set_max_frame_size(self, size):
        """Set the largest frame accepted when reading or writing.

        Raises TTransportException (INVALID_FRAME_SIZE) when ``size`` exceeds
        MAX_BIG_FRAME_SIZE, or exceeds MAX_FRAME_SIZE while the client type
        is not CLIENT_TYPE.HEADER.
        """
        if size > MAX_BIG_FRAME_SIZE:
            raise TTransportException(
                TTransportException.INVALID_FRAME_SIZE,
                "Cannot set max frame size > %s" % MAX_BIG_FRAME_SIZE,
            )
        if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER:
            raise TTransportException(
                TTransportException.INVALID_FRAME_SIZE,
                "Cannot set max frame size > %s for clients other than HEADER"
                % MAX_FRAME_SIZE,
            )
        self.__max_frame_size = size

    def get_peer_identity(self):
        """Return the identity header sent by the peer, or None.

        None is returned when the identity header is absent, or when the
        id-version header is absent or differs from ID_VERSION.
        """
        # NOTE(review): the lookup uses the str constants above; confirm the
        # key type of the parsed read headers matches.
        headers = self.__read_headers
        # .get() so that an identity sent without a version yields None
        # instead of raising KeyError.
        if headers.get(self.ID_VERSION_HEADER) == self.ID_VERSION:
            return headers.get(self.IDENTITY_HEADER)
        return None

    def set_identity(self, identity):
        """Set the identity sent with every header-format message."""
        self.__identity = identity

    def get_protocol_id(self):
        """Return the current protocol id."""
        return self.__proto_id

    def set_protocol_id(self, proto_id):
        """Set the protocol id written into outgoing headers."""
        self.__proto_id = proto_id

    def set_header(self, str_key, str_value):
        """Queue a key/value header for the next header-format flush."""
        self.__write_headers[str_key] = str_value

    def get_write_headers(self):
        """Return the dict of queued (non-persistent) outgoing headers."""
        return self.__write_headers

    def get_headers(self):
        """Return the dict of headers parsed from the last header frame."""
        return self.__read_headers

    def clear_headers(self):
        """Drop all queued (non-persistent) outgoing headers."""
        self.__write_headers.clear()

    def set_persistent_header(self, str_key, str_value):
        """Queue a key/value header to be written as a persistent header."""
        self.__write_persistent_headers[str_key] = str_value

    def get_write_persistent_headers(self):
        """Return the dict of queued persistent outgoing headers."""
        return self.__write_persistent_headers

    def clear_persistent_headers(self):
        """Drop all queued persistent outgoing headers."""
        self.__write_persistent_headers.clear()

    def add_transform(self, trans_id):
        """Append a TRANSFORM id to apply to outgoing payloads."""
        self.__write_transforms.append(trans_id)

    def _reset_protocol(self):
        """Forget the detected client type and read the next frame."""
        # HTTP calls that are one way need to flush here.
        if self.__client_type == CLIENT_TYPE.HTTP_SERVER:
            self.flush()
        # set to anything except unframed
        self.__client_type = CLIENT_TYPE.UNKNOWN
        # Read header bytes to check which protocol to decode
        self.readFrame(0)

    def getTransport(self):
        """Return the wrapped transport."""
        return self.__trans

    def isOpen(self):
        return self.getTransport().isOpen()

    def open(self):
        return self.getTransport().open()

    def close(self):
        return self.getTransport().close()

    def read(self, sz):
        """Return ``sz`` bytes, loading more data when the buffer runs short.

        For unframed client types the missing bytes come straight from the
        wrapped transport; otherwise a new frame is read first.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) == sz:
            return ret

        if self.__client_type in (
            CLIENT_TYPE.UNFRAMED_DEPRECATED,
            CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED,
        ):
            return ret + self.getTransport().readAll(sz - len(ret))

        self.readFrame(sz - len(ret))
        return ret + self.__rbuf.read(sz - len(ret))

    readAll = read  # TTransportBase.readAll does a needless copy here.

    def readFrame(self, req_sz):
        """Read the next message, detect its format and fill the read buffer.

        @param req_sz: number of bytes the caller still needs; only used for
            unframed messages, where there is no length prefix

        Raises TTransportException (INVALID_CLIENT_TYPE) when the format
        cannot be detected or is not among the supported client types.
        """
        self.__rbuf_frame = True
        word1 = self.getTransport().readAll(4)
        sz = unpack("!I", word1)[0]
        proto_id = word1[0]
        if proto_id == TBinaryProtocol.PROTOCOL_ID:
            # unframed
            self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED
            self.__proto_id = T_BINARY_PROTOCOL
            if req_sz <= 4:  # check for reads < 0.
                self.__rbuf = StringIO(word1)
            else:
                self.__rbuf = StringIO(word1 + self.getTransport().read(req_sz - 4))
        elif proto_id == TCompactProtocol.PROTOCOL_ID:
            self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED
            self.__proto_id = T_COMPACT_PROTOCOL
            if req_sz <= 4:  # check for reads < 0.
                self.__rbuf = StringIO(word1)
            else:
                self.__rbuf = StringIO(word1 + self.getTransport().read(req_sz - 4))
        elif sz == HTTP_SERVER_MAGIC:
            # The four bytes just consumed were "POST"; RequestHandler puts
            # them back before parsing the request line.
            self.__client_type = CLIENT_TYPE.HTTP_SERVER
            mf = self.getTransport().handle.makefile("rb", -1)

            self.handler = RequestHandler(mf, "client_address:port", "")
            self.header = self.handler.wfile
            self.__rbuf = StringIO(self.handler.data)
        else:
            if sz == BIG_FRAME_MAGIC:
                # The real size follows as an 8-byte integer.
                sz = unpack("!Q", self.getTransport().readAll(8))[0]
            # could be header format or framed.  Check next two bytes.
            magic = self.getTransport().readAll(2)
            proto_id = magic[0]
            if proto_id == TCompactProtocol.PROTOCOL_ID:
                self.__client_type = CLIENT_TYPE.FRAMED_COMPACT
                self.__proto_id = T_COMPACT_PROTOCOL
                _frame_size_check(sz, self.__max_frame_size, header=False)
                self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2))
            elif proto_id == TBinaryProtocol.PROTOCOL_ID:
                self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED
                self.__proto_id = T_BINARY_PROTOCOL
                _frame_size_check(sz, self.__max_frame_size, header=False)
                self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2))
            elif magic == PACKED_HEADER_MAGIC:
                self.__client_type = CLIENT_TYPE.HEADER
                _frame_size_check(sz, self.__max_frame_size)
                # flags(2), seq_id(4), header_size(2)
                n_header_meta = self.getTransport().readAll(8)
                self.__flags, self.seq_id, header_size = unpack("!HIH", n_header_meta)
                # Rebuild the whole frame in one buffer, positioned just
                # past the 10 fixed bytes already decoded above.
                data = StringIO()
                data.write(magic)
                data.write(n_header_meta)
                data.write(self.getTransport().readAll(sz - 10))
                data.seek(10)
                self.read_header_format(sz - 10, header_size, data)
            else:
                self.__client_type = CLIENT_TYPE.UNKNOWN
                raise TTransportException(
                    TTransportException.INVALID_CLIENT_TYPE,
                    "Could not detect client transport type",
                )

        if self.__client_type not in self.__supported_client_types:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Client type {} not supported on server".format(self.__client_type),
            )

    def read_header_format(self, sz, header_size, data):
        """Parse the variable part of a header-format frame.

        @param sz: number of bytes following the 10 fixed header bytes
        @param header_size: header length in 4-byte words, as sent
        @param data: buffer positioned at the start of the variable header

        On return the read buffer holds the untransformed payload.  Raises
        TTransportException for an oversized header or JSON protocol id, and
        TApplicationException for an unsupported transform id.
        """
        # clear out any previous transforms
        self.__read_transforms = []

        header_size = header_size * 4
        if header_size > sz:
            raise TTransportException(
                TTransportException.INVALID_FRAME_SIZE,
                "Header size is larger than frame",
            )
        end_header = header_size + data.tell()

        self.__proto_id = readVarint(data)
        num_headers = readVarint(data)

        if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Trying to recv JSON encoding over binary",
            )

        # Read the transform ids.  They are stored in reverse so that
        # untransform() undoes them last-applied-first.
        for _ in range(0, num_headers):
            trans_id = readVarint(data)
            if trans_id in (TRANSFORM.ZLIB, TRANSFORM.SNAPPY, TRANSFORM.ZSTD):
                self.__read_transforms.insert(0, trans_id)
            elif trans_id == TRANSFORM.HMAC:
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Hmac transform is no longer supported: %i" % trans_id,
                )
            else:
                # TApplicationException will be sent back to client
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Unknown transform in client request: %i" % trans_id,
                )

        # Clear out previous info headers.
        self.__read_headers.clear()

        # Read the info headers.
        while data.tell() < end_header:
            info_id = readVarint(data)
            if info_id == INFO.NORMAL:
                _read_info_headers(data, end_header, self.__read_headers)
            elif info_id == INFO.PERSISTENT:
                _read_info_headers(data, end_header, self.__read_persistent_headers)
            else:
                break  # Unknown header.  Stop info processing.

        # Persistent headers are visible through get_headers() as well.
        if self.__read_persistent_headers:
            self.__read_headers.update(self.__read_persistent_headers)

        # Skip the rest of the header
        data.seek(end_header)

        payload = data.read(sz - header_size)

        # Read the data section.
        self.__rbuf = StringIO(self.untransform(payload))

    def write(self, buf):
        """Buffer ``buf`` until the next flush."""
        self.__wbuf.write(buf)

    def transform(self, buf):
        """Apply the write transforms, in order, to ``buf`` and return it.

        Raises TTransportException (INVALID_TRANSFORM) for an id other than
        ZLIB, SNAPPY or ZSTD.
        """
        for trans_id in self.__write_transforms:
            if trans_id == TRANSFORM.ZLIB:
                buf = zlib.compress(buf)
            elif trans_id == TRANSFORM.SNAPPY:
                buf = snappy.compress(buf)
            elif trans_id == TRANSFORM.ZSTD:
                buf = zstd.ZstdCompressor(write_content_size=True).compress(buf)
            else:
                raise TTransportException(
                    TTransportException.INVALID_TRANSFORM,
                    "Unknown transform during send",
                )
        return buf

    def untransform(self, buf):
        """Undo the read transforms on ``buf`` and return the result.

        Every transform seen on the read side is also added to the write
        transforms if it is not already there.
        """
        for trans_id in self.__read_transforms:
            if trans_id == TRANSFORM.ZLIB:
                buf = zlib.decompress(buf)
            elif trans_id == TRANSFORM.SNAPPY:
                buf = snappy.decompress(buf)
            elif trans_id == TRANSFORM.ZSTD:
                buf = zstd.ZstdDecompressor().decompress(buf)
            if trans_id not in self.__write_transforms:
                self.__write_transforms.append(trans_id)
        return buf

    def disable_client_metadata(self):
        """Do not attach the client-metadata header to the first request."""
        self.__first_request = False

    def flush(self):
        """Send the buffered message and flush the wrapped transport."""
        self.flushImpl(False)

    def onewayFlush(self):
        """Send the buffered message using the transport's onewayFlush()."""
        self.flushImpl(True)

    def _flushHeaderMessage(self, buf, wout, wsz):
        """Write a message for CLIENT_TYPE.HEADER

        @param buf(StringIO): Buffer to write message to
        @param wout(str): Payload
        @param wsz(int): Payload length
        """
        transform_data = StringIO()
        # For now, all transforms don't require data.
        num_transforms = len(self.__write_transforms)
        for trans_id in self.__write_transforms:
            transform_data.write(getVarint(trans_id))

        # Add in special flags.
        if self.__identity:
            self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION
            self.__write_headers[self.IDENTITY_HEADER] = self.__identity

        if self.__first_request:
            self.__first_request = False
            self.__write_headers[
                self.CLIENT_METADATA_HEADER
            ] = '{"agent":"THeaderTransport.py"}'

        info_data = StringIO()

        # Write persistent kv-headers
        _flush_info_headers(
            info_data, self.get_write_persistent_headers(), INFO.PERSISTENT
        )

        # Write non-persistent kv-headers
        _flush_info_headers(info_data, self.__write_headers, INFO.NORMAL)

        header_data = StringIO()
        header_data.write(getVarint(self.__proto_id))
        header_data.write(getVarint(num_transforms))

        header_size = transform_data.tell() + header_data.tell() + info_data.tell()

        # The header length is sent in 4-byte words.  Note that an already
        # aligned header still receives four padding bytes.
        padding_size = 4 - (header_size % 4)
        header_size = header_size + padding_size

        # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2)
        wsz += header_size + 10
        if wsz > MAX_FRAME_SIZE:
            buf.write(pack("!I", BIG_FRAME_MAGIC))
            buf.write(pack("!Q", wsz))
        else:
            buf.write(pack("!I", wsz))
        buf.write(pack("!HH", HEADER_MAGIC >> 16, self.__flags))
        buf.write(pack("!I", self.seq_id))
        buf.write(pack("!H", header_size // 4))

        buf.write(header_data.getvalue())
        buf.write(transform_data.getvalue())
        buf.write(info_data.getvalue())

        # Pad out the header with 0x00
        buf.write(b"\0" * padding_size)

        # Send data section
        buf.write(wout)

    def flushImpl(self, oneway):
        """Encode the buffered message for the client type and send it.

        @param oneway: when true, finish with the transport's onewayFlush()
            instead of flush()

        Raises TTransportException for JSON over a non-HTTP client, for an
        unknown client type, or when the frame exceeds the size limit.
        """
        wout = self.__wbuf.getvalue()
        wout = self.transform(wout)
        wsz = len(wout)

        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf.seek(0)
        self.__wbuf.truncate()

        if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Trying to send JSON encoding over binary",
            )

        # Remember the format used for this message; the HTTP branch below
        # changes the stored client type after building the response.
        is_header = self.__client_type == CLIENT_TYPE.HEADER

        buf = StringIO()
        if is_header:
            self._flushHeaderMessage(buf, wout, wsz)
        elif self.__client_type in (
            CLIENT_TYPE.FRAMED_DEPRECATED,
            CLIENT_TYPE.FRAMED_COMPACT,
        ):
            buf.write(pack("!i", wsz))
            buf.write(wout)
        elif self.__client_type in (
            CLIENT_TYPE.UNFRAMED_DEPRECATED,
            CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED,
        ):
            buf.write(wout)
        elif self.__client_type == CLIENT_TYPE.HTTP_SERVER:
            # Reset the client type if we sent something -
            # oneway calls via HTTP expect a status response otherwise
            buf.write(self.header.getvalue())
            buf.write(wout)
            # Assignment (not comparison): without it _reset_protocol()
            # would flush the HTTP response header a second time.
            self.__client_type = CLIENT_TYPE.HEADER
        elif self.__client_type == CLIENT_TYPE.UNKNOWN:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE, "Unknown client type"
            )

        # We don't include the framing bytes as part of the frame size check
        # NOTE(review): wsz is the payload length only, and 4 bytes are
        # subtracted even for formats without a length prefix, so this is
        # approximate near the limits — confirm whether that matters.
        frame_size = buf.tell() - (4 if wsz < MAX_FRAME_SIZE else 12)
        _frame_size_check(
            frame_size,
            self.__max_frame_size,
            header=is_header,
        )
        self.getTransport().write(buf.getvalue())
        if oneway:
            self.getTransport().onewayFlush()
        else:
            self.getTransport().flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        """Return the read buffer, reading a first frame if none was read."""
        if not self.__rbuf_frame:
            self.readFrame(0)
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Return a new read buffer holding at least ``reqlen`` bytes.

        @param prefix: bytes already obtained by the caller; they are kept
            at the front of the new buffer
        @param reqlen: minimum number of bytes the new buffer must hold
        """
        # self.__rbuf will already be empty here because fastproto doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.

        # On unframed clients, there is a chance there is something left
        # in rbuf, and the read pointer is not advanced by fastproto
        # so seek to the end to be safe
        self.__rbuf.seek(0, 2)
        while len(prefix) < reqlen:
            prefix += self.read(reqlen)
        self.__rbuf = StringIO(prefix)
        return self.__rbuf


def _serialize_string(str_):
    """Return ``str_`` as bytes prefixed with its varint-encoded length.

    A value that is not already bytes is encoded with str.encode() first.
    """
    raw = str_ if isinstance(str_, bytes) else str_.encode()
    length_prefix = getVarint(len(raw))
    return length_prefix + raw


def _flush_info_headers(info_data, write_headers, type) -> None:
    if len(write_headers) > 0:
        info_data.write(getVarint(type))
        info_data.write(getVarint(len(write_headers)))
        write_headers_iter = write_headers.items()
        for str_key, str_value in write_headers_iter:
            info_data.write(_serialize_string(str_key))
            info_data.write(_serialize_string(str_value))
        write_headers.clear()


def _read_string(bufio, buflimit):
    """Read one varint-length-prefixed string from ``bufio``.

    Raises TTransportException (INVALID_FRAME_SIZE) when the string would
    extend past the buffer offset ``buflimit``.
    """
    length = readVarint(bufio)
    string_end = length + bufio.tell()
    if string_end > buflimit:
        raise TTransportException(
            TTransportException.INVALID_FRAME_SIZE, "String read too big"
        )
    return bufio.read(length)


def _read_info_headers(data, end_header, read_headers) -> None:
    """Read one info section from ``data`` into ``read_headers``.

    The section is a varint entry count followed by that many key/value
    pairs of length-prefixed strings; ``end_header`` is the offset no string
    may cross.
    """
    remaining = readVarint(data)
    while remaining > 0:
        key = _read_string(data, end_header)
        read_headers[key] = _read_string(data, end_header)
        remaining -= 1


def _frame_size_check(sz, set_max_size, header: bool = True) -> None:
    if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE):
        raise TTransportException(
            TTransportException.INVALID_FRAME_SIZE,
            "%s transport frame was too large" % "Header" if header else "Framed",
        )


class RequestHandler(http_server.BaseHTTPRequestHandler):
    """Parse one HTTP POST request whose leading "POST" was already consumed.

    The request body is stored in ``self.data`` (bytes) and the response
    status line and headers are collected in ``self.wfile`` so the caller
    can send them later.
    """

    # Same as superclass function, but prepend 'POST' because we
    # stripped it in the calling function.  Would be nice if
    # we had an ungetch instead
    def handle_one_request(self):
        """Read, parse and dispatch a single request."""
        self.raw_requestline = self.rfile.readline()
        if not self.raw_requestline:
            self.close_connection = True
            return
        # The request line is bytes (parse_request() decodes it), so the
        # restored method name must be bytes as well.
        self.raw_requestline = b"POST" + self.raw_requestline
        if not self.parse_request():
            # An error code has been sent, just exit
            return
        mname = "do_" + self.command
        method = getattr(self, mname, None)
        if method is None:
            self.send_error(501, "Unsupported method (%r)" % self.command)
            return
        method()

    def setup(self):
        """Read from the request object itself; buffer output in memory."""
        self.rfile = self.request
        self.wfile = StringIO()  # New output buffer

    def finish(self):
        """Close the input side only."""
        if not self.rfile.closed:
            self.rfile.close()
        # leave wfile open for reading.

    def do_POST(self):
        """Read the request body and prepare a 200 response header."""
        # A missing Content-Length header is treated as an empty body.
        content_length = int(self.headers.get("Content-Length") or 0)
        if content_length > 0:
            self.data = self.rfile.read(content_length)
        else:
            # Bytes, so the body can be wrapped in a bytes buffer like a
            # non-empty one.
            self.data = b""

        # Prepare a response header, to be sent later.
        self.send_response(200)
        self.send_header("content-type", "application/x-thrift")
        self.end_headers()
