# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-unsafe

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging

from thrift.protocol import TBinaryProtocol
from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.Thrift import TProcessor, TApplicationException
from thrift.transport import TTransport


class TConnectionContext:
    """Abstract description of a client connection.

    Concrete subclasses supply the actual address lookups.
    """

    def getPeerName(self):
        """Return the remote (client) address of this connection.

        The base implementation always raises; subclasses must override it.

        Returns:
          Whatever socket.getpeername() would return for the client socket.
        """
        raise NotImplementedError()


class TRpcConnectionContext(TConnectionContext):
    """Connection context class for thrift RPC calls.

    Holds the client transport together with the input/output protocols
    used to service it, and exposes the local and remote addresses.
    """

    def __init__(self, client_socket, iprot=None, oprot=None):
        """Initializer.

        Arguments:
          client_socket: the TSocket to the client
          iprot: input protocol for this connection (optional; can be set
            later via setProtocols)
          oprot: output protocol for this connection (optional; can be set
            later via setProtocols)
        """
        self._client_socket = client_socket
        self.iprot = iprot
        self.oprot = oprot

    def setProtocols(self, iprot, oprot):
        """Replace the input and output protocols for this connection."""
        self.iprot = iprot
        self.oprot = oprot

    def getPeerName(self):
        """Gets the address of the client.

        Returns:
          Same value as socket.getpeername() for the TSocket
        """
        return self._client_socket.getPeerName()

    def getSockName(self):
        """Gets the address of the server.

        getPeerName() goes through the transport's camelCase accessor, so the
        matching camelCase ``getSocketName()`` is preferred here too. Objects
        that only expose the raw socket API fall back to ``getsockname()``.

        Returns:
          Same value as socket.getsockname() for the TSocket
        """
        sock = self._client_socket
        # NOTE(review): assumes TSocket provides getSocketName(); confirm
        # against the TSocket implementation in use.
        get_socket_name = getattr(sock, "getSocketName", None)
        if get_socket_name is not None:
            return get_socket_name()
        return sock.getsockname()


class TServerEventHandler:
    """Base class for server lifecycle callbacks.

    Every hook here is a no-op; subclass it and override only the events
    you care about.
    """

    def preServe(self, address):
        """Hook invoked before the server starts serving.

        Arguments:
          address: the address the server is listening on
        """

    def newConnection(self, context):
        """Hook invoked when a client has connected, before processing begins.

        Arguments:
          context: instance of TRpcConnectionContext
        """

    def clientBegin(self, iprot, oprot):
        """Deprecated hook invoked when a new connection is made to the server.

        For every server other than TNonblockingServer this is invoked
        whenever newConnection is invoked, and vice versa. It is the legacy
        event-handling entry point and is not supported by
        TNonblockingServer; new code should override newConnection instead.

        Arguments:
          iprot: the input protocol for the connection
          oprot: the output protocol for the connection
        """

    def connectionDestroyed(self, context):
        """Hook invoked once a client has finished request handling.

        Arguments:
          context: instance of TRpcConnectionContext
        """


class TServer:
    """Base interface for a server, which must have a serve method.

    Constructors for all servers:

    1) (processor, serverTransport)
    2) (processor, serverTransport, transportFactory, protocolFactory)
    3) (processor, serverTransport,
        inputTransportFactory, outputTransportFactory,
        inputProtocolFactory, outputProtocolFactory)

    Optionally, the handler can be passed instead of the processor,
    and a processor will be created automatically:

    4) (handler, serverTransport)
    5) (handler, serverTransport, transportFactory, protocolFactory)
    6) (handler, serverTransport,
        inputTransportFactory, outputTransportFactory,
        inputProtocolFactory, outputProtocolFactory)

    With forms 1/4 both transport factories default to
    TTransport.TTransportFactoryBase() and both protocol factories to
    TBinaryProtocol.TBinaryProtocolFactory(). With forms 2/5 the single
    transport factory and protocol factory are used for input and output.

    The attribute serverEventHandler (default: a no-op TServerEventHandler
    instance) receives callbacks for various events in the server
    lifecycle. Replace it via setServerEventHandler() with an instance of a
    TServerEventHandler subclass.
    """

    def __init__(self, *args):
        """Dispatch on the number of positional arguments (see class docs).

        Raises:
          TypeError: if not called with 2, 4 or 6 positional arguments; any
            other arity would leave the server without a processor or
            transport configured.
        """
        if len(args) == 2:
            processor, serverTransport = args
            self.__initArgs__(
                processor,
                serverTransport,
                TTransport.TTransportFactoryBase(),
                TTransport.TTransportFactoryBase(),
                TBinaryProtocol.TBinaryProtocolFactory(),
                TBinaryProtocol.TBinaryProtocolFactory(),
            )
        elif len(args) == 4:
            processor, serverTransport, transportFactory, protocolFactory = args
            # One factory of each kind serves both directions.
            self.__initArgs__(
                processor,
                serverTransport,
                transportFactory,
                transportFactory,
                protocolFactory,
                protocolFactory,
            )
        elif len(args) == 6:
            self.__initArgs__(*args)
        else:
            raise TypeError(
                "TServer() takes 2, 4 or 6 positional arguments "
                "but {} were given".format(len(args))
            )

    def __initArgs__(
        self,
        processor,
        serverTransport,
        inputTransportFactory,
        outputTransportFactory,
        inputProtocolFactory,
        outputProtocolFactory,
    ):
        """Store the fully-expanded constructor arguments on the instance.

        ``processor`` may also be a handler; see _getProcessor.
        """
        self.processor = self._getProcessor(processor)
        self.serverTransport = serverTransport
        self.inputTransportFactory = inputTransportFactory
        self.outputTransportFactory = outputTransportFactory
        self.inputProtocolFactory = inputProtocolFactory
        self.outputProtocolFactory = outputProtocolFactory

        self.serverEventHandler = TServerEventHandler()

    def _getProcessor(self, processor):
        """Check if a processor is really a processor, or if it is a handler
        auto create a processor for it.

        Raises:
          TApplicationException: if the object is neither a TProcessor nor
            has a ``_processor_type`` attribute.
        """
        if isinstance(processor, TProcessor):
            return processor
        elif hasattr(processor, "_processor_type"):
            handler = processor
            return handler._processor_type(handler)
        else:
            raise TApplicationException(message="Could not detect processor type")

    def setServerEventHandler(self, handler):
        """Install ``handler`` (a TServerEventHandler) for lifecycle events."""
        self.serverEventHandler = handler

    def _clientBegin(self, context, iprot, oprot):
        """Fire both the new-style and the deprecated connection-start hooks."""
        self.serverEventHandler.newConnection(context)
        self.serverEventHandler.clientBegin(iprot, oprot)

    def handle(self, client):
        """Serve requests from one connected client until the loop ends.

        Wraps ``client`` with the configured transport/protocol factories,
        notifies the event handler, then calls processor.process() until a
        TTransportException or another Exception ends the loop.
        connectionDestroyed is fired whenever the connection-start hooks ran.
        Both transports are always closed, even if an event-handler hook,
        protocol factory, or processor raises something not caught here.

        Arguments:
          client: the accepted client transport
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        try:
            iprot = self.inputProtocolFactory.getProtocol(itrans)

            if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
                # With THeaderProtocolFactory the same protocol object is
                # used for both reading and writing.
                oprot = iprot
            else:
                oprot = self.outputProtocolFactory.getProtocol(otrans)

            context = TRpcConnectionContext(client, iprot, oprot)
            self._clientBegin(context, iprot, oprot)

            try:
                while True:
                    self.processor.process(iprot, oprot, context)
            except TTransport.TTransportException:
                # Ends the connection quietly (presumably a client disconnect).
                pass
            except Exception as x:
                logging.exception(x)
            finally:
                self.serverEventHandler.connectionDestroyed(context)
        finally:
            # Close both transports even if closing the first one fails.
            try:
                itrans.close()
            finally:
                otrans.close()

    def serve(self):
        """Run the server. Subclasses override this; the base does nothing."""
        pass
