thrift/lib/py/server/TServer.py (99 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-unsafe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
from thrift.protocol import TBinaryProtocol
from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.Thrift import TProcessor, TApplicationException
from thrift.transport import TTransport
class TConnectionContext:
    """Base class for connection contexts.

    Subclasses are expected to override getPeerName(); the implementation
    here only signals that no address lookup is available.
    """

    def getPeerName(self):
        """Gets the address of the client.

        Returns:
          The equivalent value of socket.getpeername() on the client socket

        Raises:
          NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()
class TRpcConnectionContext(TConnectionContext):
    """Connection context class for thrift RPC calls.

    Holds the client socket together with the input/output protocol objects
    associated with the connection.
    """

    def __init__(self, client_socket, iprot=None, oprot=None):
        """Initializer.

        Arguments:
          client_socket: the TSocket to the client
          iprot: input protocol for this connection (may be supplied later
                 through setProtocols)
          oprot: output protocol for this connection (may be supplied later
                 through setProtocols)
        """
        self._client_socket = client_socket
        self.iprot, self.oprot = iprot, oprot

    def setProtocols(self, iprot, oprot):
        """Replaces the input and output protocols stored on this context.

        Arguments:
          iprot: the new input protocol
          oprot: the new output protocol
        """
        self.iprot, self.oprot = iprot, oprot

    def getPeerName(self):
        """Gets the address of the client.

        Returns:
          Whatever the client socket object's getPeerName() returns; the
          same value as socket.getpeername() for the TSocket
        """
        peer_address = self._client_socket.getPeerName()
        return peer_address

    def getSockName(self):
        """Gets the address of the server.

        Returns:
          Whatever the client socket object's getsockname() returns; the
          same value as socket.getsockname() for the TSocket
        """
        # NOTE(review): this calls the lowercase getsockname() on the socket
        # object, whereas getPeerName() above goes through a camel-case
        # wrapper method -- confirm the socket type passed in exposes
        # getsockname() under that exact name.
        local_address = self._client_socket.getsockname()
        return local_address
class TServerEventHandler:
    """Event handler base class.

    Every hook below is a no-op that returns None. Override selected methods
    on this class to implement custom event handling.
    """

    def preServe(self, address):
        """Hook invoked before the server begins.

        Arguments:
          address: the address that the server is listening on
        """

    def newConnection(self, context):
        """Hook invoked when a client has connected and is about to begin
        processing.

        Arguments:
          context: instance of TRpcConnectionContext
        """

    def clientBegin(self, iprot, oprot):
        """Deprecated: hook invoked when a new connection is made to the
        server.

        For all servers other than TNonblockingServer, this function is
        called whenever newConnection is called and vice versa. This is the
        old style of event handling and is not supported for
        TNonblockingServer. New code should always use the newConnection
        method.

        Arguments:
          iprot: input protocol of the new connection
          oprot: output protocol of the new connection
        """

    def connectionDestroyed(self, context):
        """Hook invoked when a client has finished request-handling.

        Arguments:
          context: instance of TRpcConnectionContext
        """
class TServer:
    """Base interface for a server, which must have a serve method.

    Constructors for all servers:
      1) (processor, serverTransport)
      2) (processor, serverTransport, transportFactory, protocolFactory)
      3) (processor, serverTransport,
          inputTransportFactory, outputTransportFactory,
          inputProtocolFactory, outputProtocolFactory)

    Optionally, the handler can be passed instead of the processor,
    and a processor will be created automatically:
      4) (handler, serverTransport)
      5) (handler, serverTransport, transportFactory, protocolFactory)
      6) (handler, serverTransport,
          inputTransportFactory, outputTransportFactory,
          inputProtocolFactory, outputProtocolFactory)

    The attribute serverEventHandler (default: a plain TServerEventHandler
    instance, whose hooks do nothing) receives callbacks for various events
    in the server lifecycle. It should be set to an instance of
    TServerEventHandler.
    """

    def __init__(self, *args):
        """Dispatches on the number of positional arguments.

        Two arguments use TTransportFactoryBase and TBinaryProtocolFactory
        for both directions; four arguments share one transport factory and
        one protocol factory between input and output; six arguments give
        each direction its own factories.

        NOTE(review): any other argument count falls through without
        setting any attribute on the instance. This is preserved because
        subclasses outside this file may rely on it -- confirm before
        turning it into an error.
        """
        if len(args) == 2:
            self.__initArgs__(
                args[0],
                args[1],
                TTransport.TTransportFactoryBase(),
                TTransport.TTransportFactoryBase(),
                TBinaryProtocol.TBinaryProtocolFactory(),
                TBinaryProtocol.TBinaryProtocolFactory(),
            )
        elif len(args) == 4:
            self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
        elif len(args) == 6:
            self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])

    def __initArgs__(
        self,
        processor,
        serverTransport,
        inputTransportFactory,
        outputTransportFactory,
        inputProtocolFactory,
        outputProtocolFactory,
    ):
        """Stores the fully expanded constructor arguments on the instance.

        Arguments:
          processor: a TProcessor, or a handler from which a processor can
                     be built (see _getProcessor)
          serverTransport: the listening transport
          inputTransportFactory: factory wrapping a client for reading
          outputTransportFactory: factory wrapping a client for writing
          inputProtocolFactory: factory building the input protocol
          outputProtocolFactory: factory building the output protocol
        """
        self.processor = self._getProcessor(processor)
        self.serverTransport = serverTransport
        self.inputTransportFactory = inputTransportFactory
        self.outputTransportFactory = outputTransportFactory
        self.inputProtocolFactory = inputProtocolFactory
        self.outputProtocolFactory = outputProtocolFactory
        self.serverEventHandler = TServerEventHandler()

    def _getProcessor(self, processor):
        """Check if a processor is really a processor, or if it is a handler
        auto create a processor for it.

        Arguments:
          processor: a TProcessor instance, or a handler object exposing a
                     _processor_type attribute

        Returns:
          The processor itself, or handler._processor_type(handler)

        Raises:
          TApplicationException: if the argument is neither of the above
        """
        if isinstance(processor, TProcessor):
            return processor
        elif hasattr(processor, "_processor_type"):
            handler = processor
            return handler._processor_type(handler)
        else:
            raise TApplicationException(message="Could not detect processor type")

    def setServerEventHandler(self, handler):
        """Replaces the object that receives server lifecycle callbacks.

        Arguments:
          handler: expected to be an instance of TServerEventHandler
        """
        self.serverEventHandler = handler

    def _clientBegin(self, context, iprot, oprot):
        """Fires both the new-style and the deprecated connection hooks."""
        self.serverEventHandler.newConnection(context)
        self.serverEventHandler.clientBegin(iprot, oprot)

    def handle(self, client):
        """Serves requests from one client until its connection ends.

        Builds the transports and protocols for the client, notifies the
        event handler, then calls the processor in a loop. A
        TTransportException ends the loop silently; any other Exception is
        logged. Neither is re-raised.

        Both transports are closed on every exit path, including when
        protocol construction or an event-handler callback raises.
        connectionDestroyed is invoked whenever the connection-begin
        callbacks completed.

        Arguments:
          client: the client transport returned by the server transport
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        try:
            iprot = self.inputProtocolFactory.getProtocol(itrans)

            # With a header protocol factory the input protocol object is
            # reused for output instead of building a second protocol.
            if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
                oprot = iprot
            else:
                oprot = self.outputProtocolFactory.getProtocol(otrans)

            context = TRpcConnectionContext(client, iprot, oprot)
            self._clientBegin(context, iprot, oprot)

            try:
                try:
                    while True:
                        self.processor.process(iprot, oprot, context)
                except TTransport.TTransportException:
                    # Ends the processing loop without logging.
                    pass
                except Exception as x:
                    logging.exception(x)
            finally:
                # Pair every completed _clientBegin with a destroy
                # notification, even if processing was interrupted.
                self.serverEventHandler.connectionDestroyed(context)
        finally:
            # Always release the transports; a failure closing the input
            # side must not prevent closing the output side.
            try:
                itrans.close()
            finally:
                otrans.close()

    def serve(self):
        """Placeholder that does nothing and returns None; concrete servers
        override it."""
        pass