# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-unsafe

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import errno
import os
import select
import socket
import sys
import time
import warnings

from thrift.transport.TTransport import (
    TTransportBase,
    TTransportException,
    TServerTransportBase,
)

try:
    import fcntl
except ImportError:
    # Windows doesn't have this module
    fcntl = None


def py2_compatible_process_time():
    """Return the CPU time consumed by the current process, in seconds.

    Uses ``time.process_time()`` on Python 3.3 and newer and falls back to
    ``time.clock()`` on older interpreters.

    Note that this is CPU time, not wall-clock time: on interpreters where
    ``time.process_time()`` is used, time spent sleeping or blocked in a
    syscall is not counted.
    """
    # Compare the version as a tuple: testing major and minor separately
    # would wrongly reject a future X.0 - X.2 release with X > 3.
    if sys.version_info >= (3, 3):
        return time.process_time()
    return time.clock()


class ConnectionEpoll:
    """epoll is preferred over select due to its efficiency and ability to
    handle more than 1024 simultaneous connections"""

    def __init__(self):
        self.epoll = select.epoll()
        # TODO should we set any other masks?
        # http://docs.python.org/library/select.html#epoll-objects
        self.READ_MASK = select.EPOLLIN | select.EPOLLPRI
        self.WRITE_MASK = select.EPOLLOUT
        self.ERR_MASK = select.EPOLLERR | select.EPOLLHUP

    def read(self, fileno):
        """Watch `fileno` for read and error events.

        Any previous registration of `fileno` is dropped first.
        """
        self.unregister(fileno)
        self.epoll.register(fileno, self.READ_MASK | self.ERR_MASK)

    def write(self, fileno):
        """Watch `fileno` for write events.

        Any previous registration of `fileno` is dropped first.
        """
        self.unregister(fileno)
        self.epoll.register(fileno, self.WRITE_MASK)

    def unregister(self, fileno):
        """Stop watching `fileno`; failures are deliberately ignored."""
        try:
            self.epoll.unregister(fileno)
        except Exception:
            # Best effort: read()/write() call this unconditionally, so the
            # descriptor is frequently not registered yet.
            pass

    def process(self, timeout):
        """Wait for events and return ``(rset, wset, xset)`` lists of fds.

        @param timeout(float)  Maximum number of seconds to wait.  ``None``,
                               zero and negative values all mean "wait until
                               an event arrives".
        """
        # poll() invokes a "long" syscall that will be interrupted by any signal
        # that comes in, causing an EINTR error.  If this happens, avoid dying
        # horribly by trying again with the appropriately shortened timeout.
        #
        # The deadline must be tracked on a wall clock.  Process CPU time does
        # not advance while we are blocked in poll(), so it would barely
        # shorten the timeout on retry.  time.monotonic() is missing on old
        # interpreters, hence the time.time fallback.
        now = getattr(time, "monotonic", time.time)

        bounded = timeout is not None and timeout > 0
        deadline = now() + float(timeout or 0)
        poll_timeout = float(timeout or -1)
        while True:
            if bounded:
                poll_timeout = max(0, deadline - now())
            try:
                msgs = self.epoll.poll(timeout=poll_timeout)
                break
            except IOError as e:
                if e.errno != errno.EINTR:
                    raise

        rset = []
        wset = []
        xset = []
        for fd, mask in msgs:
            # One fd may report several conditions at once, so these checks
            # are independent rather than an if/elif chain.
            if mask & self.READ_MASK:
                rset.append(fd)
            if mask & self.WRITE_MASK:
                wset.append(fd)
            if mask & self.ERR_MASK:
                xset.append(fd)

        return rset, wset, xset


class ConnectionSelect:
    """select()-based poller exposing the same methods as ConnectionEpoll.

    A descriptor is tracked in at most one of ``readable`` / ``writable``
    at any time.
    """

    def __init__(self):
        self.readable = set()
        self.writable = set()

    def read(self, fileno):
        """Watch `fileno` for readability (and stop watching it for writes)."""
        self.writable.discard(fileno)
        self.readable.add(fileno)

    def write(self, fileno):
        """Watch `fileno` for writability (and stop watching it for reads)."""
        self.readable.discard(fileno)
        self.writable.add(fileno)

    def unregister(self, fileno):
        """Stop watching `fileno`; a no-op if it was not registered."""
        self.readable.discard(fileno)
        self.writable.discard(fileno)

    def registered(self, fileno):
        """Return True if `fileno` is currently watched for reads or writes."""
        return fileno in self.readable or fileno in self.writable

    def process(self, timeout):
        """Wait for events and return select()'s ``(rlist, wlist, xlist)``.

        Exceptional conditions are only monitored for the descriptors
        registered for reading.

        @param timeout(float)  Maximum number of seconds to wait.  ``None``,
                               zero and negative values all mean "wait until
                               an event arrives".
        """
        # select() invokes a "long" syscall that will be interrupted by any
        # signal that comes in, causing an EINTR error.  If this happens,
        # avoid dying horribly by trying again with the appropriately
        # shortened timeout.
        #
        # The deadline must be tracked on a wall clock.  Process CPU time does
        # not advance while we are blocked in select(), so it would barely
        # shorten the timeout on retry.  time.monotonic() is missing on old
        # interpreters, hence the time.time fallback.
        now = getattr(time, "monotonic", time.time)

        bounded = timeout is not None and timeout > 0
        deadline = now() + float(timeout or 0)
        poll_timeout = None
        while True:
            if bounded:
                poll_timeout = max(0, deadline - now())
            try:
                return select.select(
                    list(self.readable),
                    list(self.writable),
                    list(self.readable),
                    poll_timeout,
                )
            except (IOError, select.error) as e:
                # On Python 2 select.error is not an IOError and carries the
                # error code only as args[0]; on Python 3 both names are
                # OSError and expose .errno.
                code = getattr(e, "errno", None)
                if code is None and e.args:
                    code = e.args[0]
                if code != errno.EINTR:
                    raise

class TSocketBase(TTransportBase):
    """Common plumbing shared by connected and listening sockets.

    Open socket objects are kept in ``self.handles``, keyed by file
    descriptor.  Subclasses are expected to set ``host``, ``port``,
    ``_unix_socket`` and ``close_on_exec`` before using the helpers here.
    """

    def __init__(self):
        self.handles = {}

    def _resolveAddr(self, family=None):
        """Return getaddrinfo()-style tuples describing the target address.

        A configured unix socket path takes precedence over host/port and
        yields a single AF_UNIX entry.
        """
        if self._unix_socket is not None:
            unix_entry = (
                socket.AF_UNIX,
                socket.SOCK_STREAM,
                None,
                None,
                self._unix_socket,
            )
            return [unix_entry]

        if family is None:
            family = socket.AF_UNSPEC
        # No host requests passive (bind-to-any) address resolution.
        ai_flags = socket.AI_PASSIVE if self.host is None else 0
        return socket.getaddrinfo(
            self.host, self.port, family, socket.SOCK_STREAM, 0, ai_flags
        )

    def close(self):
        """Close every tracked socket and forget it."""
        # Snapshot the keys first: the dict is mutated inside the loop.
        for fd in list(self.handles):
            self.handles[fd].close()
            del self.handles[fd]

    def _firstHandle(self):
        """Return the first tracked socket, or raise NOT_OPEN if none."""
        for handle in self.handles.values():
            return handle
        raise TTransportException(
            TTransportException.NOT_OPEN, "Transport not open"
        )

    def getSocketName(self):
        """Return getsockname() of the first tracked socket."""
        return self._firstHandle().getsockname()

    def fileno(self):
        """Return the file descriptor of the first tracked socket."""
        return self._firstHandle().fileno()

    def setCloseOnExec(self, closeOnExec):
        """Record the close-on-exec preference and apply it to all sockets."""
        self.close_on_exec = closeOnExec
        for sock in self.handles.values():
            self._setHandleCloseOnExec(sock)

    def _setHandleCloseOnExec(self, handle):
        """Set or clear FD_CLOEXEC on `handle` according to close_on_exec."""
        # Windows doesn't have the fcntl module; nothing to do there.
        if fcntl is None:
            return

        current = fcntl.fcntl(handle, fcntl.F_GETFD, 0)
        if current < 0:
            raise IOError("Error in retrieving file options")

        if self.close_on_exec:
            updated = current | fcntl.FD_CLOEXEC
        else:
            updated = current & ~fcntl.FD_CLOEXEC
        fcntl.fcntl(handle, fcntl.F_SETFD, updated)


class TSocket(TSocketBase):
    """Connection Socket implementation of TTransport base."""

    def __init__(self, host="localhost", port=9090, unix_socket=None, family=None):
        """Initialize a TSocket

        @param host(str)  The host to connect to.
        @param port(int)  The (TCP) port to connect to.
        @param unix_socket(str)  The filename of a unix socket to connect to.
                                 (host and port will be ignored.)
        @param family(int)  Address family for connection. Ignored if
                            unix_socket is specified.
        """
        TSocketBase.__init__(self)
        self.host = host
        self.port = port
        self.handle = None
        self.family = family
        self._unix_socket = unix_socket
        self._timeout = None
        self.close_on_exec = True
        if not unix_socket:
            self.port = int(self.port)

    def __enter__(self):
        if not self.isOpen():
            self.open()
        return self

    def __exit__(self, type, value, traceback):
        if self.isOpen():
            self.close()

    def setHandle(self, h):
        """Adopt `h` as the active socket and track it in ``handles``."""
        self.handle = h
        self.handles[h.fileno()] = h

    def getHandle(self):
        """Return the active socket object, or None when not open."""
        return self.handle

    def close(self):
        """Close all tracked sockets and mark the transport as not open."""
        TSocketBase.close(self)
        self.handle = None

    def isOpen(self):
        """Return True while an active socket is set."""
        return self.handle is not None

    def setTimeout(self, ms):
        """Set the socket timeout.

        @param ms  Timeout in milliseconds, or None for no timeout.  Applied
                   immediately to an open socket and remembered for sockets
                   opened later.
        """
        if ms is None:
            self._timeout = None
        else:
            self._timeout = ms / 1000.0

        if self.handle is not None:
            self.handle.settimeout(self._timeout)

    def getPeerName(self):
        """Return getpeername() of the active socket.

        Raises TTransportException(NOT_OPEN) when the transport is not open.
        """
        if not self.handle:
            raise TTransportException(
                TTransportException.NOT_OPEN, "Transport not open"
            )
        return self.handle.getpeername()

    def open(self):
        """Connect to the configured address.

        Every resolved address is tried in order; the first successful
        connect wins.  Raises TTransportException(NOT_OPEN) if a socket error
        occurs and no address could be connected.
        """
        address = None
        try:
            res0 = self._resolveAddr(self.family)
            for res in res0:
                address = res[4]
                handle = socket.socket(res[0], res[1])
                self.setHandle(handle)
                handle.settimeout(self._timeout)
                self.setCloseOnExec(self.close_on_exec)
                try:
                    handle.connect(address)
                except socket.error:
                    # Release the failed socket before trying the next
                    # candidate; only the last failure is propagated.
                    self.close()
                    if res is not res0[-1]:
                        continue
                    else:
                        raise
                break
        except socket.error as e:
            if self._unix_socket:
                msg = "socket error connecting to path %s: %s" % (
                    self._unix_socket,
                    repr(e),
                )
            else:
                msg = "socket error connecting to host %s, port %s (%s): %s" % (
                    self.host,
                    self.port,
                    repr(address),
                    repr(e),
                )
            raise TTransportException(TTransportException.NOT_OPEN, msg)

    def read(self, sz):
        """Receive up to `sz` bytes from the socket.

        Raises TTransportException(NOT_OPEN) when the transport is not open,
        and TTransportException(END_OF_FILE) when the socket reports an error
        or returns no data.
        """
        # Match write()/getPeerName(): a closed transport is reported as
        # NOT_OPEN instead of leaking an AttributeError on None.
        if not self.handle:
            raise TTransportException(
                TTransportException.NOT_OPEN, "Transport not open"
            )
        try:
            buff = self.handle.recv(sz)
        except socket.error as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message="Socket read failed: {}".format(str(e)),
            )
        if len(buff) == 0:
            raise TTransportException(
                type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes"
            )
        return buff

    def write(self, buff):
        """Send all of `buff`, looping over partial sends.

        Raises TTransportException(NOT_OPEN) when the transport is not open,
        and TTransportException(END_OF_FILE) when the socket reports an error
        or accepts no bytes.
        """
        if not self.handle:
            raise TTransportException(
                TTransportException.NOT_OPEN, "Transport not open"
            )
        sent = 0
        have = len(buff)
        while sent < have:
            try:
                plus = self.handle.send(buff)
            except socket.error as e:
                raise TTransportException(
                    type=TTransportException.END_OF_FILE,
                    message="Socket write failed: {}".format(str(e)),
                )
            # A real check rather than an assert: asserts are stripped under
            # -O, which would turn a 0-byte send into an endless loop.
            if plus <= 0:
                raise TTransportException(
                    type=TTransportException.END_OF_FILE,
                    message="TSocket sent 0 bytes",
                )
            sent += plus
            buff = buff[plus:]

    def flush(self):
        """No-op: write() hands data to the socket immediately."""
        pass


class TServerSocket(TSocketBase, TServerTransportBase):
    """Socket implementation of TServerTransport base."""

    def __init__(self, port=9090, unix_socket=None, family=None, backlog=128):
        """Initialize a TServerSocket

        @param port(int)  The (TCP) port to listen on.
        @param unix_socket(str)  The filename of a unix socket to listen on.
                                 (port and family will be ignored.)
        @param family(int): address family for connections. Ignored if
                            unix_socket is specified.
        @param backlog(int): maximum number of connections in listen queue.
        """
        TSocketBase.__init__(self)
        self.host = None
        self.port = port
        self._unix_socket = unix_socket
        self.family = family
        self.tcp_backlog = backlog
        self.close_on_exec = True
        if not unix_socket:
            self.port = int(self.port)

        # Since we now rely on select() by default to do accepts across
        # multiple socket fds, we can receive two connections concurrently.
        # In order to maintain compatibility with the existing .accept() API,
        # we need to keep track of the accept backlog.
        self._queue = []

    def __enter__(self):
        if not self.isListening():
            self.listen()
        return self

    def __exit__(self, type, value, traceback):
        if self.isListening():
            self.close()

    def getSocketName(self):
        """Deprecated: return the name of one listening socket."""
        warnings.warn(
            "getSocketName() is deprecated for TServerSocket.  "
            "Please use getSocketNames() instead."
        )
        return TSocketBase.getSocketName(self)

    def getSocketNames(self):
        """Return getsockname() for every listening socket."""
        return [handle.getsockname() for handle in self.handles.values()]

    def fileno(self):
        """Deprecated: return the file descriptor of one listening socket."""
        warnings.warn(
            "fileno() is deprecated for TServerSocket.  "
            "Please use filenos() instead."
        )
        return TSocketBase.fileno(self)

    def filenos(self):
        """Return the file descriptor of every listening socket."""
        return [handle.fileno() for handle in self.handles.values()]

    def _cleanup_unix_socket(self, addrinfo):
        """Remove a stale unix socket file that nobody is listening on.

        Probes the path with a throw-away connection; only a refused
        connection leads to the file being unlinked.
        """
        tmp = socket.socket(addrinfo[0], addrinfo[1])
        try:
            tmp.connect(addrinfo[4])
        except socket.error as err:
            # Read the errno attribute instead of unpacking err.args, which
            # is not always an (errno, message) pair.
            if getattr(err, "errno", None) == errno.ECONNREFUSED:
                os.unlink(addrinfo[4])
        finally:
            # The probe socket is never needed again, connected or not.
            tmp.close()

    def isListening(self):
        """Return True while at least one listening socket is tracked."""
        return bool(self.handles)

    def listen(self):
        """Bind and listen on every usable resolved address.

        Raises TTransportException when no socket could be set up.  Errors
        from setsockopt/bind/listen propagate to the caller after the
        affected socket has been closed.
        """
        res0 = self._resolveAddr(self.family)

        for res in res0:
            if res[0] == socket.AF_INET6 and res[4][0] == socket.AF_INET6:
                # This happens if your version of python was built without IPv6
                # support.  getaddrinfo() will return IPv6 addresses, but the
                # contents of the address field are bogus.
                # (For example, see http://bugs.python.org/issue8858)
                #
                # Ignore IPv6 addresses if python doesn't have IPv6 support.
                continue

            # We need to remove the old unix socket if the file exists and
            # nobody is listening on it.
            if self._unix_socket:
                self._cleanup_unix_socket(res)

            # Don't complain if we can't create a socket
            # since this is handled below.
            try:
                handle = socket.socket(res[0], res[1])
            except Exception:
                continue

            try:
                handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                self._setHandleCloseOnExec(handle)

                # Always set IPV6_V6ONLY for IPv6 sockets when not on Windows
                if res[0] == socket.AF_INET6 and sys.platform != "win32":
                    handle.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)

                handle.settimeout(None)
                handle.bind(res[4])
                handle.listen(self.tcp_backlog)
            except Exception:
                # The socket is not tracked in self.handles yet, so close()
                # would never reach it; release it here before propagating.
                handle.close()
                raise

            self.handles[handle.fileno()] = handle

        if not self.handles:
            # Passed by keyword: the other call sites in this file show that
            # the first positional argument is the exception type code.
            raise TTransportException(message="No valid interfaces to listen on!")

    def _sock_accept(self):
        """Return one ``(socket, address)`` pair from a listening socket.

        Connections already accepted by an earlier call are served from the
        internal queue first.
        """
        if self._queue:
            return self._queue.pop()

        if hasattr(select, "epoll"):
            poller = ConnectionEpoll()
        else:
            poller = ConnectionSelect()

        try:
            for fd in self.handles:
                poller.read(fd)

            # Both pollers treat a timeout of 0 as "no deadline", so this
            # waits until at least one descriptor reports an event.
            readable, _, _ = poller.process(0)
        finally:
            # ConnectionEpoll holds an epoll object created for this call
            # only; close it explicitly instead of waiting for garbage
            # collection.  ConnectionSelect has no such attribute.
            epoll = getattr(poller, "epoll", None)
            if epoll is not None:
                epoll.close()

        for fd in readable:
            self._queue.append(self.handles[fd].accept())

        if not self._queue:
            # Passed by keyword: the other call sites in this file show that
            # the first positional argument is the exception type code.
            raise TTransportException(message="Accept interrupt without client?")

        return self._queue.pop()

    def accept(self):
        """Accept one connection and return it wrapped in a TSocket."""
        return self._makeTSocketFromAccepted(self._sock_accept())

    def _makeTSocketFromAccepted(self, accepted):
        """Wrap an ``(socket, address)`` pair in a TSocket; address is unused."""
        client, addr = accepted
        result = TSocket()
        result.setHandle(client)
        return result
