thrift/lib/py/transport/THttpClient.py (107 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-unsafe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import http.client as http_client
import os
import socket
import sys
import warnings
from io import BytesIO as StringIO
from urllib import parse
from thrift.transport.TTransport import TTransportBase, TTransportException
class THttpClient(TTransportBase):
    """Http implementation of TTransport base.

    Writes are buffered in memory. ``flush()`` sends the buffered bytes as the
    body of a single ``POST`` request over a freshly opened connection and
    keeps the reply, so that subsequent ``read()`` calls consume the response
    body. After a flush, ``code`` and ``headers`` hold the response status and
    response headers.
    """

    def __init__(self, uri_or_host, port=None, path=None, ssl_context=None):
        """THttpClient supports two different types constructor parameters.

        THttpClient(host, port, path) - deprecated
        THttpClient(uri)

        Only the second supports https.

        Args:
            uri_or_host: a full ``http://`` or ``https://`` URI, or (legacy
                form, selected by passing ``port``) a bare host name.
            port: legacy-form port. Passing it selects the deprecated form,
                which always uses plain HTTP.
            path: legacy-form request path; required when ``port`` is given.
            ssl_context: optional context passed to ``HTTPSConnection``;
                only used when the scheme is https.
        """
        if port is not None:
            warnings.warn(
                "Please use the THttpClient('http://host:port/path') syntax",
                DeprecationWarning,
                stacklevel=2,
            )
            self.host = uri_or_host
            self.http_host = self.host
            self.port = port
            assert path
            self.path = path
            self.scheme = "http"
        else:
            parsed = parse.urlparse(uri_or_host)
            self.scheme = parsed.scheme
            assert self.scheme in ("http", "https")
            if self.scheme == "http":
                self.port = parsed.port or http_client.HTTP_PORT
            elif self.scheme == "https":
                self.port = parsed.port or http_client.HTTPS_PORT
            self.host = parsed.hostname
            # The Host header value is "host[:port]" only (RFC 7230 sec. 5.4).
            # netloc may also contain "user:password@" userinfo, which must not
            # be sent: it would leak credentials and form an invalid header.
            self.http_host = parsed.netloc.rpartition("@")[2]
            self.path = parsed.path
            if parsed.query:
                self.path += "?%s" % parsed.query
        self.__wbuf = StringIO()
        self.__http = None
        self.__timeout = None
        self.__custom_headers = None
        self.ssl_context = ssl_context
        # Filled in by flush(); None until the first request has completed.
        self.response = None
        self.code = None
        self.headers = None

    def open(self):
        """Create a new connection object for the configured scheme.

        The timeout set via ``setTimeout`` is applied to the new connection,
        and ``ssl_context`` is passed along for https.
        """
        if self.scheme == "http":
            self.__http = http_client.HTTPConnection(
                self.host, self.port, timeout=self.__timeout
            )
        else:
            self.__http = http_client.HTTPSConnection(
                self.host, self.port, context=self.ssl_context, timeout=self.__timeout
            )

    def close(self):
        """Close the current connection, if any. Safe to call when not open."""
        if self.__http is not None:
            self.__http.close()
            self.__http = None

    def isOpen(self):
        """Return True if a connection object currently exists."""
        return self.__http is not None

    def setTimeout(self, ms):
        """Set the socket timeout in milliseconds, or None for no timeout.

        The value is applied to connections created by later ``open()`` calls.
        Every ``flush()`` opens a new connection, so it affects the next
        request.
        """
        if ms is None:
            self.__timeout = None
        else:
            self.__timeout = ms / 1000.0

    def setCustomHeaders(self, headers):
        """Replace all custom request headers with the ``headers`` mapping."""
        self.__custom_headers = headers

    def setCustomHeader(self, name, value):
        """Add or overwrite a single custom request header."""
        if self.__custom_headers is None:
            self.__custom_headers = {}
        self.__custom_headers[name] = value

    def _has_custom_header(self, name):
        """Return True if a custom header matching ``name`` is set.

        HTTP header names are case-insensitive, so "host" must suppress the
        automatic "Host" header just as "Host" does.
        """
        if not self.__custom_headers:
            return False
        wanted = name.lower()
        return any(key.lower() == wanted for key in self.__custom_headers)

    def read(self, sz):
        """Read up to ``sz`` bytes of the body of the last response.

        Raises:
            TTransportException: NOT_OPEN if ``flush()`` has not yet produced
                a response.
        """
        if self.response is None:
            raise TTransportException(
                TTransportException.NOT_OPEN,
                "No HTTP response available; call flush() first",
            )
        return self.response.read(sz)

    def write(self, buf):
        """Append ``buf`` to the outgoing request body buffer."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as an HTTP POST and fetch the response.

        Any existing connection is closed and a new one is opened for every
        request. The write buffer is reset before sending. Afterwards,
        ``response``, ``code`` and ``headers`` describe the reply.

        Raises:
            TTransportException: NOT_OPEN if host name resolution fails
                (``socket.gaierror``), UNKNOWN for any other error while
                sending the headers or body.
        """
        if self.isOpen():
            self.close()
        self.open()

        # Pull data out of buffer
        data = self.__wbuf.getvalue()
        self.__wbuf = StringIO()

        # HTTP request. skip_host=True so that a custom Host header, if set,
        # replaces the automatic one instead of being sent in addition to it.
        self.__http.putrequest("POST", self.path, skip_host=True)
        if not self._has_custom_header("Host"):
            self.__http.putheader("Host", self.http_host)
        self.__http.putheader("Content-Type", "application/x-thrift")
        self.__http.putheader("Content-Length", str(len(data)))
        if not self._has_custom_header("User-Agent"):
            user_agent = "Python/THttpClient"
            # sys.argv can be missing or empty in embedded interpreters.
            argv = getattr(sys, "argv", None) or [""]
            script = os.path.basename(argv[0])
            if script:
                user_agent = "%s (%s)" % (user_agent, parse.quote(script))
            self.__http.putheader("User-Agent", user_agent)
        if self.__custom_headers:
            for key, val in self.__custom_headers.items():
                self.__http.putheader(key, val)

        try:
            self.__http.endheaders()
            # Write payload
            self.__http.send(data)
        except socket.gaierror as e:
            raise TTransportException(TTransportException.NOT_OPEN, str(e)) from e
        except Exception as e:
            raise TTransportException(TTransportException.UNKNOWN, str(e)) from e

        # Get reply to flush the request. Errors raised here are left
        # unwrapped, to keep the exception types this step has always raised.
        self.response = self.__http.getresponse()
        self.code = self.response.status
        self.headers = self.response.getheaders()