# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-unsafe

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import http.client as http_client
import os
import socket
import sys
import warnings
from io import BytesIO as StringIO
from urllib import parse

from thrift.transport.TTransport import TTransportBase, TTransportException


class THttpClient(TTransportBase):

    """Http implementation of TTransport base.

    Writes are buffered in memory. Each flush() sends the buffered bytes as a
    single HTTP POST on a freshly opened connection; read() then consumes the
    body of that POST's response.
    """

    def __init__(self, uri_or_host, port=None, path=None, ssl_context=None):
        """THttpClient supports two different types of constructor parameters.

        THttpClient(host, port, path) - deprecated
        THttpClient(uri)

        Only the second supports https.

        :param uri_or_host: an ``http://`` or ``https://`` URI, or, in the
            deprecated form (when ``port`` is given), a bare host name.
        :param port: port for the deprecated host/port/path form; ``None``
            selects the URI form.
        :param path: request path for the deprecated form (required there).
        :param ssl_context: passed as ``context`` to ``HTTPSConnection`` when
            the URI scheme is https.
        """

        if port is not None:
            warnings.warn(
                "Please use the THttpClient('http://host:port/path') syntax",
                DeprecationWarning,
                stacklevel=2,
            )
            self.host = uri_or_host
            self.http_host = self.host
            self.port = port
            assert path
            self.path = path
            self.scheme = "http"
        else:
            parsed = parse.urlparse(uri_or_host)
            self.scheme = parsed.scheme
            assert self.scheme in ("http", "https")
            if self.scheme == "http":
                self.port = parsed.port or http_client.HTTP_PORT
            elif self.scheme == "https":
                self.port = parsed.port or http_client.HTTPS_PORT
            # hostname is what we connect to; the URI's netloc, verbatim, is
            # what gets advertised in the Host header.
            self.host = parsed.hostname
            self.http_host = parsed.netloc
            self.path = parsed.path
            if parsed.query:
                self.path += "?%s" % parsed.query
        self.__wbuf = StringIO()  # pending request body (bytes)
        self.__http = None  # current HTTP(S)Connection, or None when closed
        self.__timeout = None  # socket timeout in seconds, or None
        self.__custom_headers = None  # user-supplied headers, or None
        self.ssl_context = ssl_context

    def open(self):
        """Create a new HTTP or HTTPS connection object for this client."""
        if self.scheme == "http":
            self.__http = http_client.HTTPConnection(
                self.host, self.port, timeout=self.__timeout
            )
        else:
            self.__http = http_client.HTTPSConnection(
                self.host, self.port, context=self.ssl_context, timeout=self.__timeout
            )

    def close(self):
        """Close the current connection, if any; safe to call when not open."""
        if self.__http is not None:
            self.__http.close()
        self.__http = None

    def isOpen(self):
        """Return True if a connection object currently exists."""
        return self.__http is not None

    def setTimeout(self, ms):
        """Set the socket timeout in milliseconds (``None`` disables it).

        The value is applied to connections created by later open() calls.
        """
        if ms is None:
            self.__timeout = None
        else:
            self.__timeout = ms / 1000.0

    def setCustomHeaders(self, headers):
        """Replace all custom headers with *headers* (a mapping, or None)."""
        self.__custom_headers = headers

    def setCustomHeader(self, name, value):
        """Set a single custom header, creating the header mapping if needed."""
        if self.__custom_headers is None:
            self.__custom_headers = {}
        self.__custom_headers[name] = value

    def _has_custom_header(self, name):
        """Return True if a custom header matching *name* has been set.

        HTTP field names are case-insensitive, so "host" matches "Host".
        """
        if not self.__custom_headers:
            return False
        wanted = name.lower()
        return any(key.lower() == wanted for key in self.__custom_headers)

    @staticmethod
    def _default_user_agent():
        """Build the default User-Agent, tagged with the running script name."""
        user_agent = "Python/THttpClient"
        # sys.argv can be missing or empty in embedded interpreters.
        argv = getattr(sys, "argv", None)
        script = os.path.basename(argv[0]) if argv else ""
        if script:
            user_agent = "%s (%s)" % (user_agent, parse.quote(script))
        return user_agent

    def read(self, sz):
        """Read up to *sz* bytes from the response of the last flush()."""
        return self.response.read(sz)

    def write(self, buf):
        """Append *buf* to the pending request body."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered request as an HTTP POST and fetch the response.

        A new connection is opened for every flush. Afterwards
        ``self.response``, ``self.code`` and ``self.headers`` describe the
        server's reply, and read() consumes the response body.

        Raises:
            TTransportException: NOT_OPEN if the host name cannot be resolved
                (``socket.gaierror``); UNKNOWN for any other error raised while
                sending the headers or payload.
        """
        if self.isOpen():
            self.close()
        self.open()

        # Pull data out of buffer and reset it for the next request.
        data = self.__wbuf.getvalue()
        self.__wbuf = StringIO()

        # skip_host: the Host header is emitted explicitly below so that a
        # user-supplied custom Host header can replace it.
        self.__http.putrequest("POST", self.path, skip_host=True)

        if not self._has_custom_header("Host"):
            self.__http.putheader("Host", self.http_host)

        self.__http.putheader("Content-Type", "application/x-thrift")
        self.__http.putheader("Content-Length", str(len(data)))

        if not self._has_custom_header("User-Agent"):
            self.__http.putheader("User-Agent", self._default_user_agent())

        if self.__custom_headers:
            for key, val in self.__custom_headers.items():
                self.__http.putheader(key, val)

        try:
            # The socket is connected lazily when the headers are first sent,
            # so resolution/connection failures surface here, not in open().
            self.__http.endheaders()

            # Write payload
            self.__http.send(data)
        except socket.gaierror as e:
            raise TTransportException(TTransportException.NOT_OPEN, str(e)) from e
        except Exception as e:
            raise TTransportException(TTransportException.UNKNOWN, str(e)) from e

        # Get reply to flush the request. Errors raised by getresponse() are
        # not translated and propagate unchanged.
        self.response = self.__http.getresponse()
        self.code = self.response.status
        self.headers = self.response.getheaders()
