thrift/lib/py/Thrift.py (245 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import sys
import threading
import six
# Upper bound on how many characters of str(exception) TProcessor.writeReply
# hands to context_data.setHeaderExWhat(); longer text is truncated.
UEXW_MAX_LENGTH = 1024
class TType:
    """Integer codes identifying thrift value types.

    These are the values compared against / passed to protocol calls such as
    ``skip()``, ``readFieldBegin()`` and ``writeFieldBegin()``.
    """

    STOP = 0
    VOID = 1
    BOOL = 2
    BYTE = 3
    I08 = BYTE  # same code as BYTE
    DOUBLE = 4
    I16 = 6
    I32 = 8
    I64 = 10
    STRING = 11
    UTF7 = STRING  # same code as STRING
    STRUCT = 12
    MAP = 13
    SET = 14
    LIST = 15
    UTF8 = 16
    UTF16 = 17
    FLOAT = 19
class TMessageType:
    """Integer codes for the kind of message written by writeMessageBegin()."""

    # Consecutive values starting at 1.
    CALL, REPLY, EXCEPTION, ONEWAY = range(1, 5)
class TPriority:
    """Request priority levels (apache::thrift::concurrency::PRIORITY).

    ``N_PRIORITIES`` is one past the last level, i.e. the number of levels.
    """

    # Consecutive values starting at 0.
    (
        HIGH_IMPORTANT,
        HIGH,
        IMPORTANT,
        NORMAL,
        BEST_EFFORT,
        N_PRIORITIES,
    ) = range(6)
class TRequestContext:
    """Per-request holder for a headers object."""

    def __init__(self):
        # No headers are attached until setHeaders() is called.
        self._headers = None

    def getHeaders(self):
        """Return the headers last stored with setHeaders(), or None."""
        return self._headers

    def setHeaders(self, headers):
        """Store *headers* for later retrieval through getHeaders()."""
        self._headers = headers
class TProcessorEventHandler:
    """Event handler for thrift processors.

    Every hook is a no-op by default (only handlerError logs), so an instance
    of this base class can be used as a null-object handler. Subclasses
    override the hooks they care about.
    """

    # TODO: implement asyncComplete for Twisted

    def getHandlerContext(self, fn_name, server_context):
        """Called at the start of processing a handler method.

        The return value is what the other hooks receive as
        ``handler_context``; the default is None.
        """
        return None

    def preRead(self, handler_context, fn_name, args):
        """Called before the handler method's arguments are read."""

    def postRead(self, handler_context, fn_name, args):
        """Called after the handler method's arguments are read."""

    def preWrite(self, handler_context, fn_name, result):
        """Called before the handler method's results are written."""

    def postWrite(self, handler_context, fn_name, result):
        """Called after the handler method's results are written."""

    def handlerException(self, handler_context, fn_name, exception):
        """Called if (and only if) the handler threw an expected exception."""

    def handlerError(self, handler_context, fn_name, exception):
        """Called if (and only if) the handler threw an unexpected exception.

        Note that this method is NOT called if the handler threw an
        exception that is declared in the thrift service specification.

        The default implementation logs the failure, including the traceback
        of the exception currently being handled.
        """
        # Lazy %-style arguments: this must not itself raise when fn_name is
        # not a str (e.g. bytes), which string concatenation would do.
        logging.exception("Unexpected error in service handler %s:", fn_name)
class TServerInterface:
    """Holds the current request context in thread-local storage."""

    def __init__(self):
        # threading.local(): each thread sees only the context it set itself.
        self._tl_request_context = threading.local()

    def setRequestContext(self, request_context):
        """Remember *request_context* for the calling thread."""
        thread_state = self._tl_request_context
        thread_state.ctx = request_context

    def getRequestContext(self):
        """Return the context set on the calling thread.

        Raises AttributeError if setRequestContext() has not been called on
        this thread.
        """
        thread_state = self._tl_request_context
        return thread_state.ctx
class TProcessor:
    """Base class for processor, which works on two streams.

    Holds the dispatch table (``_processMap``: method name -> process
    function), the per-method priority table (``_priorityMap``) and the event
    handler whose hooks are invoked around reading arguments and writing
    replies.
    """

    def __init__(self):
        self._event_handler = TProcessorEventHandler()  # null object handler
        self._handler = None
        self._processMap = {}
        self._priorityMap = {}

    def setEventHandler(self, event_handler):
        """Replace the event handler whose hooks this processor invokes."""
        self._event_handler = event_handler

    def getEventHandler(self):
        """Return the currently installed event handler."""
        return self._event_handler

    def process(self, iprot, oprot, server_context=None):
        """Process one message; a no-op in this base class."""
        pass

    def onewayMethods(self):
        """Return the oneway method names; empty in this base class."""
        return ()

    def readMessageBegin(self, iprot):
        """Read the message header from *iprot*.

        Returns a ``(name, seqid)`` tuple; the message type is discarded.
        On Python 3 a non-str name is decoded as UTF-8, while a name that is
        already str is returned unchanged.
        """
        name, _, seqid = iprot.readMessageBegin()
        # Only decode when there is something to decode: calling .decode() on
        # a str would raise AttributeError on Python 3.
        if sys.version_info.major >= 3 and not isinstance(name, str):
            name = name.decode("utf8")
        return name, seqid

    def skipMessageStruct(self, iprot):
        """Skip over a message's struct payload and finish reading the message."""
        iprot.skip(TType.STRUCT)
        iprot.readMessageEnd()

    def doesKnowFunction(self, name):
        """Return True if *name* has an entry in the dispatch table."""
        return name in self._processMap

    def callFunction(self, name, seqid, iprot, oprot, server_ctx):
        """Dispatch to the process function registered under *name*.

        Raises KeyError if *name* is not registered; check with
        doesKnowFunction() first.
        """
        process_fn = self._processMap[name]
        return process_fn(self, seqid, iprot, oprot, server_ctx)

    def readArgs(self, iprot, handler_ctx, fn_name, argtype):
        """Instantiate *argtype*, fill it from *iprot* and return it.

        The event handler's preRead/postRead hooks are called around the read.
        """
        args = argtype()
        self._event_handler.preRead(handler_ctx, fn_name, args)
        args.read(iprot)
        iprot.readMessageEnd()
        self._event_handler.postRead(handler_ctx, fn_name, args)
        return args

    def writeException(self, oprot, name, seqid, exc):
        """Write *exc* to *oprot* as an EXCEPTION message and flush the transport."""
        oprot.writeMessageBegin(name, TMessageType.EXCEPTION, seqid)
        exc.write(oprot)
        oprot.writeMessageEnd()
        oprot.trans.flush()

    def get_priority(self, fname):
        """Return the priority registered for *fname*, defaulting to NORMAL."""
        return self._priorityMap.get(fname, TPriority.NORMAL)

    def _getReplyType(self, result):
        """Return EXCEPTION for a TApplicationException result, else REPLY."""
        if isinstance(result, TApplicationException):
            return TMessageType.EXCEPTION
        return TMessageType.REPLY

    @staticmethod
    def _get_exception_from_thrift_result(result):
        """Returns the wrapped exception, if present. None if not.

        result is a generated *_result object. This object either has a
        'success' field set indicating the call succeeded, or a field set
        indicating the exception thrown.
        """
        # Objects without a __dict__ expose their fields through __slots__.
        fields = (
            result.__dict__.keys() if hasattr(result, "__dict__") else result.__slots__
        )
        for field in fields:
            value = getattr(result, field)
            if value is None:
                continue
            elif field == "success":
                return None
            else:
                # First populated non-success field.
                return value
        return None

    def writeReply(self, oprot, handler_ctx, fn_name, seqid, result, server_ctx=None):
        """Serialize *result* to *oprot* as the reply for *fn_name*.

        Calls the event handler's preWrite hook first and its postWrite hook
        once the write attempt is over, whether or not it succeeded. When
        *server_ctx* has a ``context_data`` attribute and the result carries
        an exception, the exception's class name and (truncated) string form
        are recorded on it before writing.

        Any exception raised while writing closes the transport, is reported
        through the event handler's handlerError hook, and is re-raised.
        """
        self._event_handler.preWrite(handler_ctx, fn_name, result)
        reply_type = self._getReplyType(result)

        if server_ctx is not None and hasattr(server_ctx, "context_data"):
            # For an EXCEPTION reply the result itself is the exception;
            # otherwise look for one wrapped inside the result struct.
            ex = (
                result
                if reply_type == TMessageType.EXCEPTION
                else self._get_exception_from_thrift_result(result)
            )
            if ex:
                server_ctx.context_data.setHeaderEx(ex.__class__.__name__)
                server_ctx.context_data.setHeaderExWhat(str(ex)[:UEXW_MAX_LENGTH])

        try:
            oprot.writeMessageBegin(fn_name, reply_type, seqid)
            result.write(oprot)
            oprot.writeMessageEnd()
            oprot.trans.flush()
        except Exception as e:
            # Handle any thrift serialization exceptions

            # Transport is likely in a messed up state. Some data may already have
            # been written and it may not be possible to recover. Doing nothing
            # causes the client to wait until the request times out. Try to
            # close the connection to trigger a quicker failure on client side
            oprot.trans.close()

            # Let application know that there has been an exception
            self._event_handler.handlerError(handler_ctx, fn_name, e)

            # We raise the exception again to avoid any further processing
            raise
        finally:
            # Since we called preWrite, we should also call postWrite to
            # allow application to properly log their requests.
            self._event_handler.postWrite(handler_ctx, fn_name, result)
class TException(Exception):
    """Root of the thrift exception hierarchy.

    The optional ``message`` is forwarded to ``Exception`` and also kept on
    the instance as the ``message`` attribute.
    """

    # BaseException.message is deprecated in Python v[2.6,3.0); on those
    # versions only, ``message`` is a property stored in ``_message``.
    if sys.version_info[0] == 2 and sys.version_info >= (2, 6, 0):

        def _get_message(self):
            return self._message

        def _set_message(self, message):
            self._message = message

        message = property(_get_message, _set_message)

    def __init__(self, message=None):
        Exception.__init__(self, message)
        self.message = message
class TApplicationException(TException):
    """Application level thrift exceptions.

    Carries an integer ``type`` (one of the class constants below) and an
    optional ``message``, and can serialize itself to / from a protocol.
    """

    UNKNOWN = 0
    UNKNOWN_METHOD = 1
    INVALID_MESSAGE_TYPE = 2
    WRONG_METHOD_NAME = 3
    BAD_SEQUENCE_ID = 4
    MISSING_RESULT = 5
    INTERNAL_ERROR = 6
    PROTOCOL_ERROR = 7
    INVALID_TRANSFORM = 8
    INVALID_PROTOCOL = 9
    UNSUPPORTED_CLIENT_TYPE = 10
    LOADSHEDDING = 11
    TIMEOUT = 12
    INJECTED_FAILURE = 13

    # Fallback text used by __str__ when no message is set. UNKNOWN has no
    # entry and falls through to __str__'s generic default.
    EXTYPE_TO_STRING = {
        UNKNOWN_METHOD: "Unknown method",
        INVALID_MESSAGE_TYPE: "Invalid message type",
        WRONG_METHOD_NAME: "Wrong method name",
        BAD_SEQUENCE_ID: "Bad sequence ID",
        MISSING_RESULT: "Missing result",
        INTERNAL_ERROR: "Internal error",
        PROTOCOL_ERROR: "Protocol error",
        INVALID_TRANSFORM: "Invalid transform",
        INVALID_PROTOCOL: "Invalid protocol",
        UNSUPPORTED_CLIENT_TYPE: "Unsupported client type",
        LOADSHEDDING: "Loadshedding request",
        TIMEOUT: "Task timeout",
        INJECTED_FAILURE: "Injected Failure",
    }

    def __init__(self, type=UNKNOWN, message=None):
        TException.__init__(self, message)
        self.type = type

    def __str__(self):
        """Return the message if set, else a description of ``type``."""
        message = self.message
        if message:
            # read() keeps the raw bytes when they are not valid UTF-8, and
            # __str__ must return str on Python 3, so decode leniently here.
            if sys.version_info.major >= 3 and isinstance(message, bytes):
                return message.decode("utf-8", "replace")
            return message
        return self.EXTYPE_TO_STRING.get(
            self.type, "Default (unknown) TApplicationException"
        )

    def read(self, iprot):
        """Populate ``message`` and ``type`` from a struct read off *iprot*.

        Field 1 (STRING) is the message and field 2 (I32) is the type; any
        other field, or a known field with an unexpected type, is skipped.
        """
        iprot.readStructBegin()
        while True:
            (fname, ftype, fid) = iprot.readFieldBegin()
            if ftype == TType.STOP:
                break
            if fid == 1:
                if ftype == TType.STRING:
                    message = iprot.readString()
                    if sys.version_info.major >= 3 and isinstance(message, bytes):
                        try:
                            message = message.decode("utf-8")
                        except UnicodeDecodeError:
                            # Best effort: keep the undecodable bytes as-is.
                            pass
                    self.message = message
                else:
                    iprot.skip(ftype)
            elif fid == 2:
                if ftype == TType.I32:
                    self.type = iprot.readI32()
                else:
                    iprot.skip(ftype)
            else:
                iprot.skip(ftype)
            iprot.readFieldEnd()
        iprot.readStructEnd()

    def write(self, oprot):
        """Serialize this exception to *oprot* as a struct.

        ``message`` (field 1) and ``type`` (field 2) are each written only
        when not None; a non-bytes message is encoded as UTF-8 first.
        """
        oprot.writeStructBegin(b"TApplicationException")
        if self.message is not None:
            oprot.writeFieldBegin(b"message", TType.STRING, 1)
            oprot.writeString(
                self.message.encode("utf-8")
                if not isinstance(self.message, bytes)
                else self.message
            )
            oprot.writeFieldEnd()
        if self.type is not None:
            oprot.writeFieldBegin(b"type", TType.I32, 2)
            oprot.writeI32(self.type)
            oprot.writeFieldEnd()
        oprot.writeFieldStop()
        oprot.writeStructEnd()
class UnimplementedTypedef:
    """Empty marker class; it defines no attributes or behaviour of its own."""