thrift/lib/py/transport/THeaderTransport.py (480 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-unsafe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import http.server as http_server
import zlib
from io import BytesIO as StringIO
from struct import pack, unpack
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TCompactProtocol import getVarint, readVarint, TCompactProtocol
from thrift.Thrift import TApplicationException
from thrift.transport.TTransport import (
TTransportException,
TTransportBase,
CReadableTransport,
)
# Python 2 compatibility alias for ``range``.
xrange = range

# snappy is optional. When it cannot be imported, install a stand-in whose
# methods raise only if snappy compression is actually requested, so the
# transport stays usable for every other transform.
try:
    import snappy
except ImportError:

    class DummySnappy(object):
        """Stand-in for the ``snappy`` module that fails lazily on use."""

        @staticmethod
        def _unavailable():
            raise TTransportException(
                TTransportException.INVALID_TRANSFORM, "snappy module not available"
            )

        def compress(self, buf):
            self._unavailable()

        def decompress(self, buf):
            self._unavailable()

    snappy = DummySnappy()
# zstd is optional. When it cannot be imported, install a stand-in whose
# factory methods raise only if zstd compression is actually requested.
try:
    import zstd  # @manual
except ImportError:

    class DummyZstd(object):
        """Stand-in for the ``zstd`` module that fails lazily on use."""

        @staticmethod
        def _unavailable():
            raise TTransportException(
                TTransportException.INVALID_TRANSFORM, "zstd module not available"
            )

        def ZstdCompressor(self, write_content_size):
            self._unavailable()

        def ZstdDecompressor(self):
            self._unavailable()

    zstd = DummyZstd()
# Definitions from THeader.h
class CLIENT_TYPE(object):
    """Client (wire format) types a THeaderTransport can detect or speak."""

    HEADER = 0
    FRAMED_DEPRECATED = 1
    UNFRAMED_DEPRECATED = 2
    HTTP_SERVER = 3
    HTTP_CLIENT = 4
    FRAMED_COMPACT = 5
    HTTP_GET = 7
    # Assigned by THeaderTransport._reset_protocol before the next message
    # has been sniffed; flushing while in this state is an error.
    UNKNOWN = 8
    UNFRAMED_COMPACT_DEPRECATED = 9
class HEADER_FLAG(object):
    """Bit flags carried in the 16-bit flags field of a header frame."""

    SUPPORT_OUT_OF_ORDER = 1 << 0
    DUPLEX_REVERSE = 1 << 3
class TRANSFORM(object):
    """Transform ids listed in a header frame's transform section.

    ZLIB, SNAPPY and ZSTD are applied by THeaderTransport; any other id
    (including HMAC) is rejected when a header frame is read.
    """

    NONE = 0
    ZLIB = 1
    HMAC = 2
    SNAPPY = 3
    QLZ = 4
    ZSTD = 5
class INFO(object):
    """Info-header section types in a header frame.

    PERSISTENT entries are sent once by the writer and kept across frames by
    the reader; NORMAL entries are replaced on every frame.
    """

    NORMAL = 1
    PERSISTENT = 2
# Protocol ids written into / read from the header section.
T_BINARY_PROTOCOL = 0
T_COMPACT_PROTOCOL = 2

# The header magic occupies the upper 16 bits of the word that follows the
# frame length; PACKED_HEADER_MAGIC is those two bytes in network order, as
# compared against in THeaderTransport.readFrame.
HEADER_MAGIC = 0x0FFF << 16
PACKED_HEADER_MAGIC: bytes = pack(b"!H", HEADER_MAGIC >> 16)
HEADER_MASK = 0xFFFF << 16
FLAGS_MASK = 0xFFFF

# First words that identify non-Thrift or extended framing: four ASCII bytes
# interpreted as a big-endian unsigned 32-bit integer.
HTTP_SERVER_MAGIC = int.from_bytes(b"POST", "big")  # 0x504F5354
HTTP_CLIENT_MAGIC = int.from_bytes(b"HTTP", "big")  # 0x48545450
HTTP_GET_CLIENT_MAGIC = int.from_bytes(b"GET ", "big")  # 0x47455420
HTTP_HEAD_CLIENT_MAGIC = int.from_bytes(b"HEAD", "big")  # 0x48454144
BIG_FRAME_MAGIC = int.from_bytes(b"BIGF", "big")  # 0x42494746

# Frames above MAX_FRAME_SIZE are written as BIG_FRAME_MAGIC plus an 8-byte
# length; MAX_BIG_FRAME_SIZE bounds THeaderTransport.set_max_frame_size.
MAX_FRAME_SIZE = (1 << 30) - 1
MAX_BIG_FRAME_SIZE = (1 << 61) - 1
class THeaderTransport(TTransportBase, CReadableTransport):
    """Transport that sends headers. Also understands framed/unframed/HTTP
    transports and will do the right thing.

    Reads sniff the first word of each incoming message to detect the peer's
    client type; writes frame the reply according to the current client type.
    """

    __max_frame_size = MAX_FRAME_SIZE
    # Identity sent in the identity header when set; see set_identity().
    __identity = None
    # Cleared after the first header-format flush (or by
    # disable_client_metadata) so client metadata is only attached once.
    __first_request = True

    IDENTITY_HEADER = "identity"
    ID_VERSION_HEADER = "id_version"
    ID_VERSION = "1"
    CLIENT_METADATA_HEADER = "client_metadata"

    def __init__(self, trans, client_types=None, client_type=None):
        """
        @param trans: underlying transport used for all I/O
        @param client_types: iterable of CLIENT_TYPE values accepted when
            reading; defaults to (CLIENT_TYPE.HEADER,)
        @param client_type: CLIENT_TYPE used for writing; framed/unframed
            types are replaced by CLIENT_TYPE.HEADER
        """
        self.__trans = trans
        self.__rbuf = StringIO()
        # Set once readFrame has run; cstringio_buf loads the first frame
        # lazily while this is False.
        self.__rbuf_frame = False
        self.__wbuf = StringIO()
        self.seq_id = 0
        self.__flags = 0
        self.__read_transforms = []
        self.__write_transforms = []
        self.__supported_client_types = set(client_types or (CLIENT_TYPE.HEADER,))
        self.__proto_id = T_COMPACT_PROTOCOL  # default to compact like c++
        self.__client_type = client_type or CLIENT_TYPE.HEADER
        self.__read_headers = {}
        self.__read_persistent_headers = {}
        self.__write_headers = {}
        self.__write_persistent_headers = {}

        if self.__client_type in (
            CLIENT_TYPE.UNFRAMED_DEPRECATED,
            CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED,
            CLIENT_TYPE.FRAMED_DEPRECATED,
            CLIENT_TYPE.FRAMED_COMPACT,
        ):
            self.__client_type = CLIENT_TYPE.HEADER

        self.__supported_client_types.add(self.__client_type)

        # If we support unframed binary / framed binary also support compact
        if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types:
            self.__supported_client_types.add(CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED)
        if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types:
            self.__supported_client_types.add(CLIENT_TYPE.FRAMED_COMPACT)

    def set_header_flag(self, flag):
        """OR ``flag`` (a HEADER_FLAG bit) into the outgoing header flags."""
        self.__flags |= flag

    def clear_header_flag(self, flag):
        """Clear ``flag`` (a HEADER_FLAG bit) from the header flags."""
        self.__flags &= ~flag

    def header_flags(self):
        """Return the current header flags (also updated by readFrame)."""
        return self.__flags

    def set_max_frame_size(self, size):
        """Set the largest frame this transport will read or write.

        @raises TTransportException: INVALID_FRAME_SIZE if ``size`` exceeds
            MAX_BIG_FRAME_SIZE, or exceeds MAX_FRAME_SIZE for a non-HEADER
            client type
        """
        if size > MAX_BIG_FRAME_SIZE:
            raise TTransportException(
                TTransportException.INVALID_FRAME_SIZE,
                "Cannot set max frame size > %s" % MAX_BIG_FRAME_SIZE,
            )
        if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER:
            raise TTransportException(
                TTransportException.INVALID_FRAME_SIZE,
                "Cannot set max frame size > %s for clients other than HEADER"
                % MAX_FRAME_SIZE,
            )
        self.__max_frame_size = size

    def get_peer_identity(self):
        """Return the peer's identity header, or None.

        The identity is only returned when the id_version header matches
        ID_VERSION; a missing id_version counts as a mismatch rather than
        raising KeyError.
        """
        # NOTE(review): read_header_format stores keys exactly as read from a
        # BytesIO (i.e. bytes), while IDENTITY_HEADER is a str; confirm this
        # lookup can match on Python 3.
        if self.IDENTITY_HEADER in self.__read_headers:
            if self.__read_headers.get(self.ID_VERSION_HEADER) == self.ID_VERSION:
                return self.__read_headers[self.IDENTITY_HEADER]
        return None

    def set_identity(self, identity):
        """Set the identity sent with every header-format flush."""
        self.__identity = identity

    def get_protocol_id(self):
        return self.__proto_id

    def set_protocol_id(self, proto_id):
        self.__proto_id = proto_id

    def set_header(self, str_key, str_value):
        """Queue a one-shot header for the next header-format flush."""
        self.__write_headers[str_key] = str_value

    def get_write_headers(self):
        return self.__write_headers

    def get_headers(self):
        """Return the headers parsed from the last header-format frame."""
        return self.__read_headers

    def clear_headers(self):
        self.__write_headers.clear()

    def set_persistent_header(self, str_key, str_value):
        """Queue a persistent header; it is sent once on the next flush."""
        self.__write_persistent_headers[str_key] = str_value

    def get_write_persistent_headers(self):
        return self.__write_persistent_headers

    def clear_persistent_headers(self):
        self.__write_persistent_headers.clear()

    def add_transform(self, trans_id):
        """Append a TRANSFORM id applied to outgoing payloads."""
        self.__write_transforms.append(trans_id)

    def _reset_protocol(self):
        """Prepare for the next incoming message and sniff its format."""
        # HTTP calls that are one way need to flush here.
        if self.__client_type == CLIENT_TYPE.HTTP_SERVER:
            self.flush()
        # set to anything except unframed
        self.__client_type = CLIENT_TYPE.UNKNOWN
        # Read header bytes to check which protocol to decode
        self.readFrame(0)

    def getTransport(self):
        return self.__trans

    def isOpen(self):
        return self.getTransport().isOpen()

    def open(self):
        return self.getTransport().open()

    def close(self):
        return self.getTransport().close()

    def read(self, sz):
        """Return ``sz`` bytes, pulling a new frame (or, for unframed clients,
        raw transport bytes) when the buffered frame runs out."""
        ret = self.__rbuf.read(sz)
        if len(ret) == sz:
            return ret

        if self.__client_type in (
            CLIENT_TYPE.UNFRAMED_DEPRECATED,
            CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED,
        ):
            return ret + self.getTransport().readAll(sz - len(ret))

        self.readFrame(sz - len(ret))
        return ret + self.__rbuf.read(sz - len(ret))

    readAll = read  # TTransportBase.readAll does a needless copy here.

    def readFrame(self, req_sz):
        """Read the next message from the underlying transport into __rbuf.

        The first word selects the format: unframed binary/compact (only
        ``req_sz`` bytes are buffered), an HTTP POST, or a length (4 bytes,
        or BIG_FRAME_MAGIC plus 8 bytes) followed by a framed or header
        body. The detected client type is recorded and used for the reply.

        @param req_sz(int): bytes the caller needs; only used when unframed
        @raises TTransportException: INVALID_CLIENT_TYPE when the type cannot
            be detected or is not supported; INVALID_FRAME_SIZE for frames
            that are too large
        """
        self.__rbuf_frame = True
        word1 = self.getTransport().readAll(4)
        sz = unpack("!I", word1)[0]
        proto_id = word1[0]
        if proto_id == TBinaryProtocol.PROTOCOL_ID:
            # unframed
            self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED
            self.__proto_id = T_BINARY_PROTOCOL
            if req_sz <= 4:  # check for reads < 0.
                self.__rbuf = StringIO(word1)
            else:
                # readAll: the caller needs exactly req_sz bytes; a short
                # read() here would hand the protocol a truncated buffer.
                self.__rbuf = StringIO(word1 + self.getTransport().readAll(req_sz - 4))
        elif proto_id == TCompactProtocol.PROTOCOL_ID:
            self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED
            self.__proto_id = T_COMPACT_PROTOCOL
            if req_sz <= 4:  # check for reads < 0.
                self.__rbuf = StringIO(word1)
            else:
                self.__rbuf = StringIO(word1 + self.getTransport().readAll(req_sz - 4))
        elif sz == HTTP_SERVER_MAGIC:
            self.__client_type = CLIENT_TYPE.HTTP_SERVER
            mf = self.getTransport().handle.makefile("rb", -1)

            self.handler = RequestHandler(mf, "client_address:port", "")
            self.header = self.handler.wfile
            self.__rbuf = StringIO(self.handler.data)
        else:
            if sz == BIG_FRAME_MAGIC:
                sz = unpack("!Q", self.getTransport().readAll(8))[0]
            # could be header format or framed. Check next two bytes.
            magic = self.getTransport().readAll(2)
            proto_id = magic[0]
            if proto_id == TCompactProtocol.PROTOCOL_ID:
                self.__client_type = CLIENT_TYPE.FRAMED_COMPACT
                self.__proto_id = T_COMPACT_PROTOCOL
                _frame_size_check(sz, self.__max_frame_size, header=False)
                self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2))
            elif proto_id == TBinaryProtocol.PROTOCOL_ID:
                self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED
                self.__proto_id = T_BINARY_PROTOCOL
                _frame_size_check(sz, self.__max_frame_size, header=False)
                self.__rbuf = StringIO(magic + self.getTransport().readAll(sz - 2))
            elif magic == PACKED_HEADER_MAGIC:
                self.__client_type = CLIENT_TYPE.HEADER
                _frame_size_check(sz, self.__max_frame_size)
                # flags(2), seq_id(4), header_size(2)
                n_header_meta = self.getTransport().readAll(8)
                self.__flags, self.seq_id, header_size = unpack("!HIH", n_header_meta)
                data = StringIO()
                data.write(magic)
                data.write(n_header_meta)
                data.write(self.getTransport().readAll(sz - 10))
                data.seek(10)
                self.read_header_format(sz - 10, header_size, data)
            else:
                self.__client_type = CLIENT_TYPE.UNKNOWN
                raise TTransportException(
                    TTransportException.INVALID_CLIENT_TYPE,
                    "Could not detect client transport type",
                )

        if self.__client_type not in self.__supported_client_types:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Client type {} not supported on server".format(self.__client_type),
            )

    def read_header_format(self, sz, header_size, data):
        """Parse a header-format frame body and load its payload into __rbuf.

        @param sz(int): bytes remaining in the frame after the fixed 10-byte
            prefix
        @param header_size(int): header length in 4-byte words
        @param data(BytesIO): the frame, positioned just after the prefix
        @raises TTransportException: if the header is larger than the frame or
            a JSON protocol id arrives outside HTTP
        @raises TApplicationException: INVALID_TRANSFORM for HMAC or unknown
            transform ids
        """
        # clear out any previous transforms
        self.__read_transforms = []

        header_size = header_size * 4
        if header_size > sz:
            raise TTransportException(
                TTransportException.INVALID_FRAME_SIZE,
                "Header size is larger than frame",
            )
        end_header = header_size + data.tell()

        self.__proto_id = readVarint(data)
        num_headers = readVarint(data)

        if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Trying to recv JSON encoding over binary",
            )

        # Read the headers. Data for each header varies.
        for _ in range(0, num_headers):
            trans_id = readVarint(data)
            if trans_id in (TRANSFORM.ZLIB, TRANSFORM.SNAPPY, TRANSFORM.ZSTD):
                # Undo in reverse order of application.
                self.__read_transforms.insert(0, trans_id)
            elif trans_id == TRANSFORM.HMAC:
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Hmac transform is no longer supported: %i" % trans_id,
                )
            else:
                # TApplicationException will be sent back to client
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Unknown transform in client request: %i" % trans_id,
                )

        # Clear out previous info headers.
        self.__read_headers.clear()

        # Read the info headers.
        while data.tell() < end_header:
            info_id = readVarint(data)
            if info_id == INFO.NORMAL:
                _read_info_headers(data, end_header, self.__read_headers)
            elif info_id == INFO.PERSISTENT:
                _read_info_headers(data, end_header, self.__read_persistent_headers)
            else:
                break  # Unknown header.  Stop info processing.

        if self.__read_persistent_headers:
            self.__read_headers.update(self.__read_persistent_headers)

        # Skip the rest of the header
        data.seek(end_header)

        payload = data.read(sz - header_size)

        # Read the data section.
        self.__rbuf = StringIO(self.untransform(payload))

    def write(self, buf):
        """Buffer ``buf``; nothing is sent until flush()."""
        self.__wbuf.write(buf)

    def transform(self, buf):
        """Apply the configured write transforms to ``buf`` in order.

        @raises TTransportException: INVALID_TRANSFORM for an unsupported id
        """
        for trans_id in self.__write_transforms:
            if trans_id == TRANSFORM.ZLIB:
                buf = zlib.compress(buf)
            elif trans_id == TRANSFORM.SNAPPY:
                buf = snappy.compress(buf)
            elif trans_id == TRANSFORM.ZSTD:
                buf = zstd.ZstdCompressor(write_content_size=True).compress(buf)
            else:
                raise TTransportException(
                    TTransportException.INVALID_TRANSFORM,
                    "Unknown transform during send",
                )
        return buf

    def untransform(self, buf):
        """Undo the read transforms and enable the same ones for replies."""
        for trans_id in self.__read_transforms:
            if trans_id == TRANSFORM.ZLIB:
                buf = zlib.decompress(buf)
            elif trans_id == TRANSFORM.SNAPPY:
                buf = snappy.decompress(buf)
            elif trans_id == TRANSFORM.ZSTD:
                buf = zstd.ZstdDecompressor().decompress(buf)
            if trans_id not in self.__write_transforms:
                self.__write_transforms.append(trans_id)
        return buf

    def disable_client_metadata(self):
        """Do not attach the client_metadata header to the next flush."""
        self.__first_request = False

    def flush(self):
        self.flushImpl(False)

    def onewayFlush(self):
        self.flushImpl(True)

    def _flushHeaderMessage(self, buf, wout, wsz):
        """Write a message for CLIENT_TYPE.HEADER

        @param buf(BytesIO): Buffer to write message to
        @param wout(bytes): Payload (already transformed)
        @param wsz(int): Payload length
        @return(int): number of framing bytes written ahead of the frame body:
            4 for a normal length, 12 for BIG_FRAME_MAGIC plus an 8-byte length
        """
        transform_data = StringIO()
        # For now, all transforms don't require data.
        num_transforms = len(self.__write_transforms)
        for trans_id in self.__write_transforms:
            transform_data.write(getVarint(trans_id))

        # Add in special flags.
        if self.__identity:
            self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION
            self.__write_headers[self.IDENTITY_HEADER] = self.__identity

        if self.__first_request:
            self.__first_request = False
            self.__write_headers[
                self.CLIENT_METADATA_HEADER
            ] = '{"agent":"THeaderTransport.py"}'

        info_data = StringIO()

        # Write persistent kv-headers (cleared once written)
        _flush_info_headers(
            info_data, self.get_write_persistent_headers(), INFO.PERSISTENT
        )

        # Write non-persistent kv-headers (cleared once written)
        _flush_info_headers(info_data, self.__write_headers, INFO.NORMAL)

        header_data = StringIO()
        header_data.write(getVarint(self.__proto_id))
        header_data.write(getVarint(num_transforms))

        header_size = transform_data.tell() + header_data.tell() + info_data.tell()

        # The header length is sent in 4-byte words, so pad up to a multiple
        # of 4 (always adding at least one byte, as before).
        padding_size = 4 - (header_size % 4)
        header_size = header_size + padding_size

        # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2)
        wsz += header_size + 10
        if wsz > MAX_FRAME_SIZE:
            buf.write(pack("!I", BIG_FRAME_MAGIC))
            buf.write(pack("!Q", wsz))
            framing_size = 12
        else:
            buf.write(pack("!I", wsz))
            framing_size = 4
        buf.write(pack("!HH", HEADER_MAGIC >> 16, self.__flags))
        buf.write(pack("!I", self.seq_id))
        buf.write(pack("!H", header_size // 4))

        buf.write(header_data.getvalue())
        buf.write(transform_data.getvalue())
        buf.write(info_data.getvalue())

        # Pad out the header with 0x00
        buf.write(b"\0" * padding_size)

        # Send data section
        buf.write(wout)
        return framing_size

    def flushImpl(self, oneway):
        """Frame the buffered payload for the current client type and send it.

        @param oneway(bool): call the underlying transport's onewayFlush
            instead of flush
        @raises TTransportException: INVALID_CLIENT_TYPE for JSON outside HTTP
            or an UNKNOWN client type; INVALID_FRAME_SIZE if the frame is too
            large
        """
        wout = self.__wbuf.getvalue()
        wout = self.transform(wout)
        wsz = len(wout)

        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf.seek(0)
        self.__wbuf.truncate()

        if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Trying to send JSON encoding over binary",
            )

        is_header = self.__client_type == CLIENT_TYPE.HEADER
        # Length-prefix bytes written ahead of the frame body; they are not
        # counted by the frame size check. Unframed and HTTP output has none.
        framing_size = 0

        buf = StringIO()
        if is_header:
            framing_size = self._flushHeaderMessage(buf, wout, wsz)
        elif self.__client_type in (
            CLIENT_TYPE.FRAMED_DEPRECATED,
            CLIENT_TYPE.FRAMED_COMPACT,
        ):
            buf.write(pack("!i", wsz))
            buf.write(wout)
            framing_size = 4
        elif self.__client_type in (
            CLIENT_TYPE.UNFRAMED_DEPRECATED,
            CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED,
        ):
            buf.write(wout)
        elif self.__client_type == CLIENT_TYPE.HTTP_SERVER:
            buf.write(self.header.getvalue())
            buf.write(wout)
        elif self.__client_type == CLIENT_TYPE.UNKNOWN:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE, "Unknown client type"
            )

        # We don't include the framing bytes as part of the frame size check
        _frame_size_check(
            buf.tell() - framing_size,
            self.__max_frame_size,
            header=is_header,
        )

        if self.__client_type == CLIENT_TYPE.HTTP_SERVER:
            # Reset the client type once a response has been produced so that
            # _reset_protocol does not flush a second, empty HTTP response;
            # oneway calls via HTTP (never flushed here) still get their status
            # response from _reset_protocol.
            self.__client_type = CLIENT_TYPE.HEADER

        self.getTransport().write(buf.getvalue())
        if oneway:
            self.getTransport().onewayFlush()
        else:
            self.getTransport().flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        """Read buffer handed to fastproto; loads the first frame lazily."""
        if not self.__rbuf_frame:
            self.readFrame(0)
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Rebuild the read buffer for fastproto.

        @param prefix(bytes): bytes fastproto already took from the exhausted
            buffer; they are served again at the front of the new buffer
        @param reqlen(int): total bytes fastproto needs, including ``prefix``
        @return(BytesIO): the new read buffer
        """
        # self.__rbuf will already be empty here because fastproto doesn't
        # ask for a refill until the previous buffer is empty. Therefore,
        # we can start reading new frames immediately.

        # On unframed clients, there is a chance there is something left
        # in rbuf, and the read pointer is not advanced by fastproto
        # so seek to the end to be safe
        self.__rbuf.seek(0, 2)
        while len(prefix) < reqlen:
            # reqlen already counts prefix: only fetch what is still missing so
            # unframed reads never block waiting for bytes past the message.
            prefix += self.read(reqlen - len(prefix))
        # Keep whatever remains of a freshly read frame (empty for unframed
        # clients) instead of dropping it when __rbuf is replaced below.
        prefix += self.__rbuf.read()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
def _serialize_string(str_):
    """Return ``str_`` as a varint byte-length prefix followed by its bytes.

    Text is encoded with ``str.encode()`` first; bytes are used as-is.
    """
    encoded = str_ if isinstance(str_, bytes) else str_.encode()
    return getVarint(len(encoded)) + encoded
def _flush_info_headers(info_data, write_headers, type) -> None:
    """Serialize one info-header section into ``info_data`` and empty it.

    Nothing is written for an empty ``write_headers``. Otherwise the section
    is the varint section ``type`` (an INFO value), the varint entry count,
    then each key and value as length-prefixed strings. ``write_headers`` is
    cleared afterwards, so each entry is sent only once.
    """
    if not write_headers:
        return
    info_data.write(getVarint(type))
    info_data.write(getVarint(len(write_headers)))
    for key, value in write_headers.items():
        info_data.write(_serialize_string(key))
        info_data.write(_serialize_string(value))
    write_headers.clear()
def _read_string(bufio, buflimit):
    """Read a varint-length-prefixed byte string from ``bufio``.

    @param buflimit(int): offset the string must not extend past
    @raises TTransportException: INVALID_FRAME_SIZE when it would
    """
    length = readVarint(bufio)
    end = bufio.tell() + length
    if end > buflimit:
        raise TTransportException(
            TTransportException.INVALID_FRAME_SIZE, "String read too big"
        )
    return bufio.read(length)
def _read_info_headers(data, end_header, read_headers) -> None:
    """Read one info-header section (varint count, then key/value strings)
    from ``data`` into ``read_headers``, never reading past ``end_header``."""
    count = readVarint(data)
    for _ in range(count):
        key = _read_string(data, end_header)
        read_headers[key] = _read_string(data, end_header)
def _frame_size_check(sz, set_max_size, header: bool = True) -> None:
    """Raise if a frame of ``sz`` bytes is larger than allowed.

    @param sz(int): frame size, excluding the length prefix
    @param set_max_size(int): the transport's configured maximum
    @param header(bool): True for header-format frames, which are limited
        only by ``set_max_size``; other frames are also capped at
        MAX_FRAME_SIZE
    @raises TTransportException: INVALID_FRAME_SIZE when a limit is exceeded
    """
    if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE):
        raise TTransportException(
            TTransportException.INVALID_FRAME_SIZE,
            # Parenthesized: otherwise % binds before the conditional and the
            # non-header message degrades to just "Framed".
            "%s transport frame was too large" % ("Header" if header else "Framed"),
        )
class RequestHandler(http_server.BaseHTTPRequestHandler):
    """Parses an HTTP POST whose leading ``POST`` word was already consumed.

    After construction, ``data`` holds the request body (bytes) and ``wfile``
    (an in-memory buffer, left open) holds the 200 status line and response
    headers, which THeaderTransport sends ahead of the reply payload.
    """

    # Same as superclass function, but append 'POST' because we
    # stripped it in the calling function. Would be nice if
    # we had an ungetch instead
    def handle_one_request(self):
        # NOTE(review): the inherited handle() keeps calling this while
        # close_connection is false (e.g. an HTTP/1.1 keep-alive request);
        # confirm callers only ever see one request per connection.
        self.raw_requestline = self.rfile.readline()
        if not self.raw_requestline:
            self.close_connection = True
            return
        # rfile is binary, so the restored method must be bytes too;
        # parse_request decodes raw_requestline itself.
        self.raw_requestline = b"POST" + self.raw_requestline
        if not self.parse_request():
            # An error code has been sent, just exit
            return
        mname = "do_" + self.command
        if not hasattr(self, mname):
            self.send_error(501, "Unsupported method (%r)" % self.command)
            return
        method = getattr(self, mname)
        method()

    def setup(self):
        """Read from the given request file; buffer output in memory."""
        self.rfile = self.request
        self.wfile = StringIO()  # New output buffer

    def finish(self):
        if not self.rfile.closed:
            self.rfile.close()
        # leave wfile open for reading.

    def do_POST(self):
        """Store the request body in ``data`` and queue a 200 response header."""
        content_length = int(self.headers["Content-Length"])
        if content_length > 0:
            self.data = self.rfile.read(content_length)
        else:
            # bytes, because readFrame wraps this value in a BytesIO
            self.data = b""

        # Prepare a response header, to be sent later.
        self.send_response(200)
        self.send_header("content-type", "application/x-thrift")
        self.end_headers()