util/ServerSocket.cpp (283 lines of code) (raw):
/**
* Copyright (c) 2014-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <wdt/util/ServerSocket.h>

#include <fcntl.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>

#include <algorithm>
#include <vector>

#include <folly/Conv.h>
#include <glog/logging.h>
namespace facebook {
namespace wdt {
/**
 * Builds a server socket that will listen on `port` with the given listen
 * backlog. The underlying WdtSocket is created with the supplied encryption
 * parameters, IV change interval and tag-verification callback (which is
 * moved into it).
 */
ServerSocket::ServerSocket(ThreadCtx &threadCtx, int port, int backlog,
                           const EncryptionParams &encryptionParams,
                           int64_t ivChangeInterval,
                           Func &&tagVerificationSuccessCallback)
    : threadCtx_(threadCtx), backlog_(backlog) {
  auto wdtSocket = std::make_unique<WdtSocket>(
      threadCtx, port, encryptionParams, ivChangeInterval,
      std::move(tagVerificationSuccessCallback));
  // Kept for backward compatibility (presumably with peers that do not use
  // encryption — see WdtSocket::enableUnencryptedPeerSupport).
  wdtSocket->enableUnencryptedPeerSupport();
  socket_ = std::move(wdtSocket);
}
/**
 * Calls closeNoCheck() (defined elsewhere; presumably closes the accepted
 * connection) and then closes every non-negative listening fd. close()
 * failures are logged but not reported to the caller. listeningFds_ is
 * empty afterwards.
 */
void ServerSocket::closeAllNoCheck() {
  const int port = socket_->getPort();
  const int connectionFd = socket_->getFd();
  WVLOG(1) << "Destroying server socket (port, listen fd, fd) " << port << ", "
           << listeningFds_ << ", " << connectionFd;
  closeNoCheck();
  // Listen-side close errors are deliberately not propagated; the error that
  // matters to callers is the encryption error.
  for (const auto listenFd : listeningFds_) {
    if (listenFd < 0) {
      continue;
    }
    if (::close(listenFd) != 0) {
      WPLOG(ERROR)
          << "Error closing listening fd for server socket. listeningFd: "
          << listenFd << " port: " << port;
    }
  }
  listeningFds_.clear();
}
/**
 * Releases all descriptors owned by this server socket (listening fds, and
 * whatever closeNoCheck() closes) without surfacing errors.
 */
ServerSocket::~ServerSocket() {
  this->closeAllNoCheck();
}
/**
 * Creates a socket for one resolved address, applies socket options, binds it
 * and starts listening with backlog_.
 *
 * SO_REUSEADDR (and IPV6_V6ONLY for AF_INET6) failures are logged only.
 * Socket creation, bind and listen failures are fatal for this address.
 *
 * @param info  resolved address to bind (family/socktype/protocol/addr).
 * @param host  printable host, used only in log messages.
 * @return the listening fd, or -1 on failure (the fd is closed on failure).
 */
int ServerSocket::listenInternal(struct addrinfo *info,
                                 const std::string &host) {
  const int port = socket_->getPort();
  WVLOG(1) << "Will listen on " << host << " " << port << " "
           << info->ai_family;
  const int fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol);
  if (fd == -1) {
    WPLOG(WARNING) << "Error making server socket " << host << " " << port;
    return -1;
  }
  setReceiveBufferSize(fd);
  const int enable = 1;
  if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) != 0) {
    WPLOG(ERROR) << "Unable to set SO_REUSEADDR option " << host << " "
                 << port;
  }
  // IPv6 sockets are made v6-only so IPv4 is served by its own socket.
  const bool isIpv6 = (info->ai_family == AF_INET6);
  if (isIpv6 &&
      setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &enable, sizeof(enable)) != 0) {
    WPLOG(ERROR) << "Unable to set IPV6_V6ONLY flag " << host << " " << port;
  }
  if (bind(fd, info->ai_addr, info->ai_addrlen) != 0) {
    WPLOG(WARNING) << "Error binding " << host << " " << port;
    ::close(fd);
    return -1;
  }
  if (::listen(fd, backlog_) != 0) {
    WPLOG(ERROR) << "listen error for port " << host << " " << port;
    ::close(fd);
    return -1;
  }
  return fd;
}
/**
 * Finds the port the kernel assigned to `listeningFd` (after binding to port
 * 0) and resolves a fresh passive address list for that port.
 *
 * @param listeningFd  a bound, listening socket (IPv4 or IPv6).
 * @param sa           getaddrinfo() hints to reuse for the new lookup.
 * @param host         printable host, used only in log messages.
 * @param infoList     out: list returned by getaddrinfo(); the caller owns it
 *                     and must freeaddrinfo() it. Meaningful only when the
 *                     return value is >= 0.
 * @return the selected port, or -1 on failure.
 */
int ServerSocket::getSelectedPortAndNewAddress(int listeningFd,
                                               struct addrinfo &sa,
                                               const std::string &host,
                                               addrInfoList &infoList) {
  // sockaddr_storage is large enough for both IPv4 and IPv6 addresses. A
  // sockaddr_in is too small for an IPv6 socket and would be truncated.
  struct sockaddr_storage addr {};
  socklen_t len = sizeof(addr);
  if (getsockname(listeningFd, reinterpret_cast<struct sockaddr *>(&addr),
                  &len) == -1) {
    WPLOG(ERROR) << "getsockname failed " << host;
    return -1;
  }
  int port = -1;
  if (addr.ss_family == AF_INET) {
    port = ntohs(reinterpret_cast<const struct sockaddr_in *>(&addr)->sin_port);
  } else if (addr.ss_family == AF_INET6) {
    port = ntohs(
        reinterpret_cast<const struct sockaddr_in6 *>(&addr)->sin6_port);
  } else {
    WLOG(ERROR) << "getsockname returned unexpected address family "
                << addr.ss_family << " " << host;
    return -1;
  }
  WVLOG(1) << "auto configuring port to " << port;
  const std::string portStr = folly::to<std::string>(port);
  const int res = getaddrinfo(nullptr, portStr.c_str(), &sa, &infoList);
  if (res) {
    // getaddrinfo errors are not errno values, so WLOG (not WPLOG) is used.
    WLOG(ERROR) << "getaddrinfo failed " << host << " " << port << " : "
                << gai_strerror(res);
    return -1;
  }
  if (infoList == nullptr) {
    WLOG(ERROR) << "getaddrinfo unexpectedly returned nullptr " << host << " "
                << port;
    return -1;
  }
  return port;
}
/**
 * Starts listening on every passive SOCK_STREAM address returned by
 * getaddrinfo() for the socket's port. Does nothing if already listening.
 *
 * Address family: AF_INET6 when options.ipv6 is set, AF_INET when
 * options.ipv4 is set (ipv4 wins if both are set), otherwise unspecified.
 * Without static_ports the port is forced to 0. The port the kernel picks
 * for the first successfully bound address is then stored via
 * socket_->setPort() and reused for the remaining address families.
 *
 * @return OK if at least one listening fd was created, CONN_ERROR if the
 *         initial getaddrinfo() failed, CONN_ERROR_RETRYABLE if no address
 *         could be bound/listened on.
 */
ErrorCode ServerSocket::listen() {
  if (!listeningFds_.empty()) {
    return OK;
  }
  struct addrinfo sa;
  int port = socket_->getPort();
  memset(&sa, 0, sizeof(sa));
  const WdtOptions &options = threadCtx_.getOptions();
  if (options.ipv6) {
    sa.ai_family = AF_INET6;
  }
  // Checked after ipv6, so ipv4 takes precedence when both flags are set.
  if (options.ipv4) {
    sa.ai_family = AF_INET;
  }
  sa.ai_socktype = SOCK_STREAM;
  sa.ai_flags = AI_PASSIVE;
  // Dynamic port is the default on receiver (and setting the start_port flag
  // explicitly automatically also sets static_ports to false)
  if (!options.static_ports) {
    WVLOG(1) << "Not using static_ports, changing port " << port << " to 0";
    port = 0;
  }
  // Lookup
  addrInfoList infoList = nullptr;
  std::string portStr = folly::to<std::string>(port);
  int res = getaddrinfo(nullptr, portStr.c_str(), &sa, &infoList);
  if (res) {
    // not errno, can't use WPLOG (perror)
    WLOG(ERROR) << "Failed getaddrinfo ai_passive on " << port << " : " << res
                << " : " << gai_strerror(res);
    return CONN_ERROR;
  }
  // if the port specified is 0, then a random port is selected for the first
  // address. We use that same port for other address types. Another addrinfo
  // list is created using the new port. This variable is used to ensure that we
  // do not try to bind again to the previous type.
  int addressTypeAlreadyBound = AF_UNSPEC;
  // `info` is advanced manually: after a dynamic port is chosen, iteration
  // restarts from the head of the newly resolved list.
  for (struct addrinfo *info = infoList; info != nullptr;) {
    if (info->ai_family == addressTypeAlreadyBound) {
      // we are already listening for this address type
      WVLOG(2) << "Ignoring address family " << info->ai_family
               << " since we are already listing on it " << port;
      info = info->ai_next;
      continue;
    }
    std::string host, portString;
    if (WdtSocket::getNameInfo(info->ai_addr, info->ai_addrlen, host,
                               portString)) {
      // even if getnameinfo fails, we can still continue. Error is logged
      // inside SocketUtils. On success the resolved port must match ours.
      WDT_CHECK(port == folly::to<int32_t>(portString));
    }
    int listeningFd = listenInternal(info, host);
    if (listeningFd < 0) {
      info = info->ai_next;
      continue;
    }
    int addressFamily = info->ai_family;
    if (port == 0) {
      // First successful bind on a dynamic port: learn the chosen port and
      // re-resolve all addresses with it.
      addrInfoList newInfoList = nullptr;
      int selectedPort =
          getSelectedPortAndNewAddress(listeningFd, sa, host, newInfoList);
      if (selectedPort < 0) {
        ::close(listeningFd);
        info = info->ai_next;
        continue;
      }
      port = selectedPort;
      socket_->setPort(port);
      addressTypeAlreadyBound = addressFamily;
      // Swap in the new list; the old one (which `info` points into) is
      // freed, so iteration restarts from the new head.
      freeaddrinfo(infoList);
      infoList = newInfoList;
      info = infoList;
    } else {
      info = info->ai_next;
    }
    WVLOG(1) << "Successful listen on " << listeningFd << " host " << host
             << " port " << port << " ai_family " << addressFamily;
    listeningFds_.emplace_back(listeningFd);
  }
  freeaddrinfo(infoList);
  if (listeningFds_.empty()) {
    WLOG(ERROR) << "Unable to listen port " << port;
    return CONN_ERROR_RETRYABLE;
  }
  return OK;
}
/**
 * Waits up to `timeoutMillis` for a connection on any listening fd and
 * accepts it, storing the new fd (with timeouts and DSCP applied) in socket_
 * and recording the peer ip/port.
 *
 * Listening fds are checked round-robin starting from lastCheckedPollIndex_.
 * With `tryCurAddressFirst` the index is not advanced before checking.
 *
 * @return OK on success; the listen() error if listening failed; ABORT if the
 *         abort checker fires; CONN_ERROR on timeout, poll()/accept() error,
 *         or when no fd reported POLLIN.
 */
ErrorCode ServerSocket::acceptNextConnection(int timeoutMillis,
                                             bool tryCurAddressFirst) {
  ErrorCode code = listen();
  if (code != OK) {
    return code;
  }
  WDT_CHECK(!listeningFds_.empty());
  WDT_CHECK(timeoutMillis > 0);
  auto port = socket_->getPort();
  auto fd = socket_->getFd();
  const WdtOptions &options = threadCtx_.getOptions();
  const bool checkAbort = (options.abort_check_interval_millis > 0);
  const int numFds = static_cast<int>(listeningFds_.size());
  // std::vector instead of a variable-length array: VLAs are a compiler
  // extension, not standard C++.
  std::vector<struct pollfd> pollFds(numFds);
  auto startTime = Clock::now();
  while (true) {
    // poll() can return early (e.g. interrupted by a signal) or time out on
    // the shorter abort-check interval, so retry with the remaining timeout.
    int timeElapsed = durationMillis(Clock::now() - startTime);
    if (timeElapsed >= timeoutMillis) {
      WVLOG(3) << "accept() timed out";
      return CONN_ERROR;
    }
    int pollTimeout = timeoutMillis - timeElapsed;
    if (checkAbort) {
      if (threadCtx_.getAbortChecker()->shouldAbort()) {
        WLOG(ERROR) << "Transfer aborted during accept " << port << " " << fd;
        return ABORT;
      }
      pollTimeout = std::min(pollTimeout, options.abort_check_interval_millis);
    }
    for (int i = 0; i < numFds; i++) {
      pollFds[i] = {listeningFds_[i], POLLIN, 0};
    }
    const int retValue = poll(pollFds.data(), numFds, pollTimeout);
    if (retValue > 0) {
      break;
    }
    if (retValue == 0) {
      // Plain timeout. poll() does not set errno in this case, so errno must
      // not be consulted here (it may hold a stale EINTR).
      WVLOG(3) << "poll() timed out on port : " << port
               << ", listening fds : " << listeningFds_;
      continue;
    }
    // retValue < 0: errno is meaningful only here.
    if (errno == EINTR) {
      WVLOG(1) << "poll() call interrupted. retrying...";
      continue;
    }
    WPLOG(ERROR) << "poll() failed on port : " << port
                 << ", listening fds : " << listeningFds_;
    return CONN_ERROR;
  }
  if (lastCheckedPollIndex_ >= numFds) {
    // can happen if getaddrinfo returns different set of addresses
    lastCheckedPollIndex_ = 0;
  } else if (!tryCurAddressFirst) {
    // else try the next address
    lastCheckedPollIndex_ = (lastCheckedPollIndex_ + 1) % numFds;
  }
  for (int count = 0; count < numFds; count++) {
    auto &pollFd = pollFds[lastCheckedPollIndex_];
    if (pollFd.revents & POLLIN) {
      struct sockaddr_storage addr;
      socklen_t addrLen = sizeof(addr);
      fd = accept(pollFd.fd, (struct sockaddr *)&addr, &addrLen);
      if (fd < 0) {
        WPLOG(ERROR) << "accept error";
        return CONN_ERROR;
      }
      WdtSocket::getNameInfo((struct sockaddr *)&addr, addrLen, peerIp_,
                             peerPort_);
      WVLOG(1) << "New connection, fd : " << fd << " from " << peerIp_ << " "
               << peerPort_;
      socket_->setFd(fd);
      socket_->setSocketTimeouts();
      socket_->setDscp(options.dscp);
      return OK;
    }
    lastCheckedPollIndex_ = (lastCheckedPollIndex_ + 1) % numFds;
  }
  WLOG(ERROR) << "None of the listening fds got a POLLIN event " << port;
  return CONN_ERROR;
}
/**
 * Applies options.receive_buffer_size as SO_RCVBUF on `fd` when it is
 * positive. Failures are logged and otherwise ignored.
 *
 * @param fd  socket to configure.
 */
void ServerSocket::setReceiveBufferSize(int fd) {
  const auto port = socket_->getPort();
  const int bufSize = threadCtx_.getOptions().receive_buffer_size;
  if (bufSize > 0) {
    const int status =
        ::setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
    if (status == 0) {
      WVLOG(1) << "Receive buffer size set to " << bufSize << " port " << port;
    } else {
      WPLOG(ERROR) << "Failed to set receive buffer " << port << " size "
                   << bufSize << " fd " << fd;
    }
  }
}
std::string ServerSocket::getPeerIp() const {
// we keep returning the peer ip for error printing
return peerIp_;
}
std::string ServerSocket::getPeerPort() const {
// we keep returning the peer port for error printing
return peerPort_;
}
int ServerSocket::getBackLog() const {
return backlog_;
}
}
} // end namespace facebook::wdt