thrift/lib/py/transport/TSocket.py (354 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-unsafe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import errno
import os
import select
import socket
import sys
import time
import warnings
from thrift.transport.TTransport import (
TTransportBase,
TTransportException,
TServerTransportBase,
)
try:
import fcntl
except ImportError:
# Windows doesn't have this module
fcntl = None
def py2_compatible_process_time():
    """Return the process CPU time in seconds, portably across Python versions.

    Uses ``time.process_time()`` when the interpreter provides it (Python 3.3+)
    and falls back to ``time.clock()`` otherwise. ``time.clock`` was removed
    in Python 3.8, so the fallback must only be taken on old interpreters.

    Note: this measures CPU time, not elapsed wall-clock time, so it is not
    suitable for computing I/O timeouts.
    """
    # Feature-detect rather than compare version components individually:
    # the previous `major >= 3 and minor >= 3` test would wrongly choose
    # time.clock() on any future X.0 - X.2 release.
    process_time = getattr(time, "process_time", None)
    if process_time is not None:
        return process_time()
    return time.clock()
class ConnectionEpoll:
    """epoll-based readiness poller.

    epoll is preferred over select due to its efficiency and ability to
    handle more than 1024 simultaneous connections. Each file descriptor is
    watched for either reading or writing at any one time.
    """

    def __init__(self):
        self.epoll = select.epoll()
        # TODO should we set any other masks?
        # http://docs.python.org/library/select.html#epoll-objects
        self.READ_MASK = select.EPOLLIN | select.EPOLLPRI
        self.WRITE_MASK = select.EPOLLOUT
        self.ERR_MASK = select.EPOLLERR | select.EPOLLHUP

    def read(self, fileno):
        """Watch ``fileno`` for readability and errors, replacing prior interest."""
        self.unregister(fileno)
        self.epoll.register(fileno, self.READ_MASK | self.ERR_MASK)

    def write(self, fileno):
        """Watch ``fileno`` for writability, replacing prior interest."""
        self.unregister(fileno)
        # ERR_MASK is not requested here; per epoll_ctl(2) the kernel reports
        # EPOLLERR/EPOLLHUP regardless of the requested mask.
        self.epoll.register(fileno, self.WRITE_MASK)

    def unregister(self, fileno):
        """Stop watching ``fileno``; silently ignored if it is not registered."""
        try:
            self.epoll.unregister(fileno)
        except (IOError, OSError, ValueError):
            # Expected when read()/write() see a descriptor for the first
            # time (not registered yet) or the epoll object is closed.
            pass

    def process(self, timeout):
        """Wait until at least one registered descriptor is ready.

        @param timeout: maximum wait in seconds. ``None`` or ``0`` (any falsy
            value) means block until an event arrives.
        @return: ``(rset, wset, xset)`` lists of readable, writable and
            errored file descriptors.
        """
        # poll() invokes a "long" syscall that will be interrupted by any signal
        # that comes in, causing an EINTR error. If this happens, avoid dying
        # horribly by trying again with the appropriately shortened timeout.
        # The deadline must be tracked on a wall clock: CPU time barely moves
        # while we are blocked in poll(), so a CPU-time deadline would restart
        # (almost) the full timeout after every interruption.
        now = getattr(time, "monotonic", time.time)
        deadline = now() + float(timeout or 0)
        # A negative epoll timeout blocks indefinitely.
        poll_timeout = float(timeout or -1)
        while True:
            if timeout is not None and timeout > 0:
                poll_timeout = max(0, deadline - now())
            try:
                msgs = self.epoll.poll(timeout=poll_timeout)
                break
            except IOError as e:
                if e.errno != errno.EINTR:
                    raise
        rset = []
        wset = []
        xset = []
        for fd, mask in msgs:
            if mask & self.READ_MASK:
                rset.append(fd)
            if mask & self.WRITE_MASK:
                wset.append(fd)
            if mask & self.ERR_MASK:
                xset.append(fd)
        return rset, wset, xset
class ConnectionSelect:
    """select()-based readiness poller, used when epoll is unavailable.

    Each file descriptor is watched for either reading or writing at any
    one time; readable descriptors are also checked for exceptional
    conditions.
    """

    def __init__(self):
        self.readable = set()
        self.writable = set()

    def read(self, fileno):
        """Watch ``fileno`` for readability, replacing prior interest."""
        self.writable.discard(fileno)
        self.readable.add(fileno)

    def write(self, fileno):
        """Watch ``fileno`` for writability, replacing prior interest."""
        self.readable.discard(fileno)
        self.writable.add(fileno)

    def unregister(self, fileno):
        """Stop watching ``fileno``; a no-op if it is not registered."""
        self.readable.discard(fileno)
        self.writable.discard(fileno)

    def registered(self, fileno):
        """Return True if ``fileno`` is currently being watched."""
        return fileno in self.readable or fileno in self.writable

    def process(self, timeout):
        """Wait until at least one registered descriptor is ready.

        @param timeout: maximum wait in seconds. ``None`` or a non-positive
            value means block until a descriptor is ready.
        @return: the ``(readable, writable, exceptional)`` lists from
            ``select.select``.
        """
        # select() invokes a "long" syscall that will be interrupted by any
        # signal that comes in, causing an EINTR error. If this happens,
        # avoid dying horribly by trying again with the appropriately
        # shortened timeout. The deadline is tracked on a wall clock because
        # CPU time does not advance while blocked in select().
        now = getattr(time, "monotonic", time.time)
        deadline = now() + float(timeout or 0)
        poll_timeout = timeout if timeout is None or timeout > 0 else None
        while True:
            if timeout is not None and timeout > 0:
                poll_timeout = max(0, deadline - now())
            try:
                return select.select(
                    list(self.readable),
                    list(self.writable),
                    list(self.readable),
                    poll_timeout,
                )
            except (IOError, OSError, select.error) as e:
                # On Python 2, select.error is not an IOError subclass and
                # carries its errno only in args[0].
                code = getattr(e, "errno", None)
                if code is None and e.args:
                    code = e.args[0]
                if code != errno.EINTR:
                    raise
class TSocketBase(TTransportBase):
    """Base class for both connected and listening sockets.

    Owned sockets live in ``self.handles``, a dict keyed by file descriptor.
    Subclasses are expected to set ``host``, ``port``, ``_unix_socket`` and
    ``close_on_exec`` before calling the helpers that read them.
    """

    def __init__(self):
        # fileno -> socket object for every socket owned by this transport.
        self.handles = {}

    def _resolveAddr(self, family=None):
        """Resolve this transport's address into getaddrinfo()-style tuples.

        @param family(int) Address family; None means socket.AF_UNSPEC.
                           Ignored for unix sockets.
        @return list of (family, type, proto, canonname, sockaddr) tuples.
                For a unix socket a single synthetic AF_UNIX entry is returned.
        """
        if family is None:
            family = socket.AF_UNSPEC
        if self._unix_socket is not None:
            return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)]
        # No host means "any local address": AI_PASSIVE asks getaddrinfo for
        # wildcard addresses suitable for bind().
        ai_flags = socket.AI_PASSIVE if self.host is None else 0
        return socket.getaddrinfo(
            self.host, self.port, family, socket.SOCK_STREAM, 0, ai_flags
        )

    def close(self):
        """Close every owned handle and forget it."""
        # Iterate over a snapshot of the keys since entries are removed as we go.
        for key in list(self.handles):
            self.handles.pop(key).close()

    def _firstHandle(self):
        """Return one owned handle, raising NOT_OPEN if there are none."""
        if not self.handles:
            raise TTransportException(
                TTransportException.NOT_OPEN, "Transport not open"
            )
        return next(iter(self.handles.values()))

    def getSocketName(self):
        """Return the local address of the first owned handle."""
        return self._firstHandle().getsockname()

    def fileno(self):
        """Return the file descriptor of the first owned handle."""
        return self._firstHandle().fileno()

    def setCloseOnExec(self, closeOnExec):
        """Set (or clear) FD_CLOEXEC on every owned handle.

        The flag is also remembered in ``self.close_on_exec`` for handles
        configured later.
        """
        self.close_on_exec = closeOnExec
        for handle in self.handles.values():
            self._setHandleCloseOnExec(handle)

    def _setHandleCloseOnExec(self, handle):
        """Apply ``self.close_on_exec`` to a single handle via fcntl."""
        # Windows doesn't have this module, don't set the handle in this case.
        if fcntl is None:
            return
        flags = fcntl.fcntl(handle, fcntl.F_GETFD, 0)
        # fcntl.fcntl() raises OSError on failure; this check is defensive.
        if flags < 0:
            raise IOError("Error in retrieving file options")
        if self.close_on_exec:
            fcntl.fcntl(handle, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
        else:
            fcntl.fcntl(handle, fcntl.F_SETFD, flags & ~fcntl.FD_CLOEXEC)
class TSocket(TSocketBase):
    """Connection Socket implementation of TTransport base."""

    def __init__(self, host="localhost", port=9090, unix_socket=None, family=None):
        """Initialize a TSocket

        @param host(str) The host to connect to.
        @param port(int) The (TCP) port to connect to.
        @param unix_socket(str) The filename of a unix socket to connect to.
                                (host and port will be ignored.)
        @param family(int) Address family for connection. Ignored if
                           unix_socket is specified.
        """
        TSocketBase.__init__(self)
        self.host = host
        self.port = port
        # The connected socket, or None while closed.
        self.handle = None
        self.family = family
        self._unix_socket = unix_socket
        # Socket timeout in seconds (None = blocking); see setTimeout().
        self._timeout = None
        self.close_on_exec = True
        if not unix_socket:
            self.port = int(self.port)

    def __enter__(self):
        if not self.isOpen():
            self.open()
        return self

    def __exit__(self, type, value, traceback):
        if self.isOpen():
            self.close()

    def setHandle(self, h):
        """Adopt ``h`` as this transport's socket."""
        self.handle = h
        self.handles[h.fileno()] = h

    def getHandle(self):
        return self.handle

    def close(self):
        TSocketBase.close(self)
        self.handle = None

    def isOpen(self):
        return self.handle is not None

    def _requireOpen(self):
        """Raise TTransportException(NOT_OPEN) if there is no socket."""
        if not self.handle:
            raise TTransportException(
                TTransportException.NOT_OPEN, "Transport not open"
            )

    def setTimeout(self, ms):
        """Set the socket timeout in milliseconds; None means blocking."""
        if ms is None:
            self._timeout = None
        else:
            self._timeout = ms / 1000.0
        if self.handle is not None:
            self.handle.settimeout(self._timeout)

    def getPeerName(self):
        """Return the remote address of the connected socket."""
        self._requireOpen()
        return self.handle.getpeername()

    def open(self):
        """Connect to the first reachable resolved address.

        Addresses from _resolveAddr() are tried in order. A failure to create,
        configure or connect a socket closes it and moves on to the next
        address.

        @raises TTransportException(NOT_OPEN) if resolution fails or no
                address could be connected to.
        """
        address = None
        try:
            addrs = self._resolveAddr(self.family)
            last_index = len(addrs) - 1
            for index, res in enumerate(addrs):
                address = res[4]
                try:
                    handle = socket.socket(res[0], res[1])
                    self.setHandle(handle)
                    handle.settimeout(self._timeout)
                    self.setCloseOnExec(self.close_on_exec)
                    handle.connect(address)
                except socket.error:
                    # Drop the half-configured handle before trying the next one.
                    self.close()
                    if index < last_index:
                        continue
                    raise
                break
            else:
                # Only reached when resolution produced no addresses at all.
                raise socket.error("no addresses resolved")
        except socket.error as e:
            if self._unix_socket:
                msg = "socket error connecting to path %s: %s" % (
                    self._unix_socket,
                    repr(e),
                )
            else:
                msg = "socket error connecting to host %s, port %s (%s): %s" % (
                    self.host,
                    self.port,
                    repr(address),
                    repr(e),
                )
            raise TTransportException(TTransportException.NOT_OPEN, msg)

    def read(self, sz):
        """Read up to ``sz`` bytes.

        @raises TTransportException NOT_OPEN if the transport is closed, or
                END_OF_FILE if the peer closed the connection or recv() failed.
        """
        self._requireOpen()
        try:
            buff = self.handle.recv(sz)
        except socket.error as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message="Socket read failed: {}".format(str(e)),
            )
        if len(buff) == 0:
            raise TTransportException(
                type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes"
            )
        return buff

    def write(self, buff):
        """Send all of ``buff``.

        @raises TTransportException NOT_OPEN if the transport is closed, or
                END_OF_FILE if sending fails.
        """
        self._requireOpen()
        # A memoryview lets us resend the unsent tail without copying it,
        # avoiding quadratic work when send() only accepts part of the data.
        view = memoryview(buff)
        sent = 0
        have = len(view)
        while sent < have:
            try:
                plus = self.handle.send(view[sent:])
            except socket.error as e:
                raise TTransportException(
                    type=TTransportException.END_OF_FILE,
                    message="Socket write failed: {}".format(str(e)),
                )
            if plus <= 0:
                # Explicit check instead of assert (stripped under -O), which
                # would otherwise leave this loop spinning forever.
                raise TTransportException(
                    type=TTransportException.END_OF_FILE,
                    message="Socket write failed: sent 0 bytes",
                )
            sent += plus

    def flush(self):
        pass
class TServerSocket(TSocketBase, TServerTransportBase):
    """Socket implementation of TServerTransport base.

    Listens on every address resolved for ``port`` with no host (so the
    resolver is asked for wildcard addresses), or on a unix socket path.
    """

    def __init__(self, port=9090, unix_socket=None, family=None, backlog=128):
        """Initialize a TServerSocket

        @param port(int) The (TCP) port to listen on.
        @param unix_socket(str) The filename of a unix socket to listen on.
                                (port and family will be ignored.)
        @param family(int): address family for connections. Ignored if
                            unix_socket is specified.
        @param backlog(int): maximum number of connections in listen queue.
        """
        TSocketBase.__init__(self)
        self.host = None
        self.port = port
        self._unix_socket = unix_socket
        self.family = family
        self.tcp_backlog = backlog
        self.close_on_exec = True
        if not unix_socket:
            self.port = int(self.port)
        # Since we now rely on select() by default to do accepts across
        # multiple socket fds, we can receive two connections concurrently.
        # In order to maintain compatibility with the existing .accept() API,
        # we need to keep track of the accept backlog.
        self._queue = []

    def __enter__(self):
        if not self.isListening():
            self.listen()
        return self

    def __exit__(self, type, value, traceback):
        if self.isListening():
            self.close()

    def getSocketName(self):
        warnings.warn(
            "getSocketName() is deprecated for TServerSocket. "
            "Please use getSocketNames() instead.",
            stacklevel=2,
        )
        return TSocketBase.getSocketName(self)

    def getSocketNames(self):
        """Return the local address of every listening handle."""
        return [handle.getsockname() for handle in self.handles.values()]

    def fileno(self):
        warnings.warn(
            "fileno() is deprecated for TServerSocket. "
            "Please use filenos() instead.",
            stacklevel=2,
        )
        return TSocketBase.fileno(self)

    def filenos(self):
        """Return the file descriptor of every listening handle."""
        return [handle.fileno() for handle in self.handles.values()]

    def _cleanup_unix_socket(self, addrinfo):
        """Remove a stale unix socket file nobody is listening on.

        Probes the path with a connect(); the file is unlinked only when the
        connection is refused. The probe socket is always closed.
        """
        tmp = socket.socket(addrinfo[0], addrinfo[1])
        try:
            tmp.connect(addrinfo[4])
        except socket.error as err:
            # Use the errno attribute: err.args does not always hold exactly
            # (errno, message), so unpacking it could itself raise.
            if getattr(err, "errno", None) == errno.ECONNREFUSED:
                os.unlink(addrinfo[4])
        finally:
            tmp.close()

    def isListening(self):
        return bool(self.handles)

    def listen(self):
        """Bind and listen on every usable resolved address.

        @raises TTransportException if no address could be listened on.
        """
        res0 = self._resolveAddr(self.family)
        for res in res0:
            if res[0] == socket.AF_INET6 and res[4][0] == socket.AF_INET6:
                # This happens if your version of python was built without IPv6
                # support. getaddrinfo() will return IPv6 addresses, but the
                # contents of the address field are bogus.
                # (For example, see http://bugs.python.org/issue8858)
                #
                # Ignore IPv6 addresses if python doesn't have IPv6 support.
                continue
            # We need remove the old unix socket if the file exists and
            # nobody is listening on it.
            if self._unix_socket:
                self._cleanup_unix_socket(res)
            # Don't complain if we can't create a socket
            # since this is handled below.
            try:
                handle = socket.socket(res[0], res[1])
            except Exception:
                continue
            try:
                handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                self._setHandleCloseOnExec(handle)
                # Always set IPV6_V6ONLY for IPv6 sockets when not on Windows
                if res[0] == socket.AF_INET6 and sys.platform != "win32":
                    handle.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
                handle.settimeout(None)
                handle.bind(res[4])
                handle.listen(self.tcp_backlog)
            except Exception:
                # Don't leak the half-configured socket when setup fails.
                handle.close()
                raise
            self.handles[handle.fileno()] = handle
        if not self.handles:
            # Pass the text as `message`; the first positional argument of
            # TTransportException is the exception type code.
            raise TTransportException(message="No valid interfaces to listen on!")

    def _sock_accept(self):
        """Return one ``(socket, address)`` pair from an accepted connection.

        Blocks until at least one listening handle is readable, then accepts
        on every ready handle and queues the extras for later calls.
        """
        if self._queue:
            return self._queue.pop()
        if hasattr(select, "epoll"):
            poller = ConnectionEpoll()
        else:
            poller = ConnectionSelect()
        try:
            for fileno in self.handles:
                poller.read(fileno)
            # A timeout of 0 makes both pollers block until a handle is ready.
            readable, _, _ = poller.process(0)
        finally:
            # Release the epoll descriptor now rather than at garbage collection.
            epoll = getattr(poller, "epoll", None)
            if epoll is not None:
                epoll.close()
        for fd in readable:
            self._queue.append(self.handles[fd].accept())
        if not self._queue:
            raise TTransportException(message="Accept interrupt without client?")
        return self._queue.pop()

    def accept(self):
        """Block for the next client connection and wrap it in a TSocket."""
        return self._makeTSocketFromAccepted(self._sock_accept())

    def _makeTSocketFromAccepted(self, accepted):
        """Wrap a ``(socket, address)`` pair from accept() in a TSocket."""
        client, addr = accepted
        result = TSocket()
        result.setHandle(client)
        return result