util/ClientSocket.cpp (169 lines of code) (raw):

/**
 * Copyright (c) 2014-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <wdt/util/ClientSocket.h>
#include <wdt/Reporting.h>
#include <fcntl.h>
#include <glog/logging.h>
#include <poll.h>
#include <sys/socket.h>
#include <folly/Conv.h>
#include <folly/ScopeGuard.h>

namespace facebook {
namespace wdt {

using std::string;

/**
 * Builds a client socket for the given destination host and port.
 *
 * No network activity happens here: the constructor creates the underlying
 * WdtSocket and prepares the getaddrinfo() hints stored in sa_. The hints are
 * zeroed, restricted to stream sockets, and the address family is taken from
 * the thread options (ipv4 wins when both ipv6 and ipv4 are set; the family
 * stays zeroed when neither is set).
 *
 * @param threadCtx         per-thread context (options, abort checker)
 * @param dest              destination host name or address
 * @param port              destination port
 * @param encryptionParams  forwarded to the underlying WdtSocket
 * @param ivChangeInterval  forwarded to the underlying WdtSocket
 */
ClientSocket::ClientSocket(ThreadCtx &threadCtx, const string &dest,
                           const int port,
                           const EncryptionParams &encryptionParams,
                           int64_t ivChangeInterval)
    : dest_(dest), threadCtx_(threadCtx) {
  memset(&sa_, 0, sizeof(sa_));
  socket_ = std::make_unique<WdtSocket>(threadCtx, port, encryptionParams,
                                        ivChangeInterval, nullptr);
  if (threadCtx_.getOptions().ipv6) {
    sa_.ai_family = AF_INET6;
  }
  if (threadCtx_.getOptions().ipv4) {
    sa_.ai_family = AF_INET;
  }
  sa_.ai_socktype = SOCK_STREAM;
}

/**
 * Resolves dest_ and tries each returned address in turn until one connects.
 *
 * Each attempt uses a non-blocking connect() followed by a poll() loop so the
 * abort checker can be consulted at least every abort_check_interval_millis
 * and the overall wait is bounded by connect_timeout_millis. On success the
 * socket is switched back to blocking mode, peerIp_ is recorded, and socket
 * timeouts and DSCP are applied.
 *
 * Must only be called when no connection is open (checked via WDT_CHECK).
 *
 * @return OK                    when connected
 *         CONN_ERROR            when name resolution or poll() fails
 *         ABORT                 when the abort checker fires while waiting
 *         CONN_ERROR_RETRYABLE  when the connect times out or no address
 *                               could be connected
 */
ErrorCode ClientSocket::connect() {
  auto fd = socket_->getFd();
  auto port = socket_->getPort();
  WDT_CHECK(fd < 0) << "Previous connection not closed " << fd << " " << port;
  // Lookup
  struct addrinfo *infoList = nullptr;
  // Releases the resolver results on every exit path of this function
  auto guard = folly::makeGuard([&] {
    if (infoList) {
      freeaddrinfo(infoList);
    }
  });
  string portStr = folly::to<string>(port);
  int res = getaddrinfo(dest_.c_str(), portStr.c_str(), &sa_, &infoList);
  if (res) {
    // not errno, can't use WPLOG (perror)
    WLOG(ERROR) << "Failed getaddrinfo " << dest_ << " , " << port << " : "
                << res << " : " << gai_strerror(res);
    return CONN_ERROR;
  }
  int count = 0;
  for (struct addrinfo *info = infoList; info != nullptr;
       info = info->ai_next) {
    ++count;
    // Textual form of the candidate address; named peerPort so it does not
    // shadow the numeric port above
    std::string host, peerPort;
    WdtSocket::getNameInfo(info->ai_addr, info->ai_addrlen, host, peerPort);
    WVLOG(2) << "will connect to " << host << " " << peerPort;
    fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol);
    if (fd == -1) {
      WPLOG(WARNING) << "Error making socket for port " << peerPort;
      continue;
    }
    WVLOG(1) << "new socket " << fd << " for port " << peerPort;
    socket_->setFd(fd);
    setSendBufferSize();
    // make the socket non blocking
    int sockArg = fcntl(fd, F_GETFL, nullptr);
    sockArg |= O_NONBLOCK;
    res = fcntl(fd, F_SETFL, sockArg);
    if (res < 0) {
      WPLOG(ERROR) << "Failed to make the socket non-blocking " << peerPort
                   << " sock " << sockArg << " res " << res;
      closeConnection();
      continue;
    }
    if (::connect(fd, info->ai_addr, info->ai_addrlen) != 0) {
      if (errno != EINPROGRESS) {
        WPLOG(INFO) << "Error connecting on " << host << " " << peerPort;
        closeConnection();
        continue;
      }
      auto startTime = Clock::now();
      int connectTimeout = threadCtx_.getOptions().connect_timeout_millis;
      while (true) {
        // check for abort
        if (threadCtx_.getAbortChecker()->shouldAbort()) {
          WLOG(ERROR) << "Transfer aborted during connect " << peerPort << " "
                      << fd;
          closeConnection();
          return ABORT;
        }
        // we need this loop because poll() can return before any file handles
        // have changes or before timing out. In that case, we retry poll with
        // a reduced timeout. Also we set the poll timeout to be at max equal
        // to abort check interval. This allows us to check for abort
        // regularly.
        int timeElapsed = durationMillis(Clock::now() - startTime);
        if (timeElapsed >= connectTimeout) {
          WVLOG(1) << "connect() timed out" << host << " " << peerPort;
          closeConnection();
          return CONN_ERROR_RETRYABLE;
        }
        int pollTimeout =
            std::min(connectTimeout - timeElapsed,
                     threadCtx_.getOptions().abort_check_interval_millis);
        struct pollfd pollFds[] = {{fd, POLLOUT, 0}};
        res = poll(pollFds, 1, pollTimeout);
        if (res == 0) {
          // Timeout: poll() does not set errno here, so check this before
          // looking at errno to avoid acting on a stale value
          WVLOG(1) << "poll() timed out " << host << " " << peerPort;
          continue;
        }
        if (res < 0) {
          if (errno == EINTR) {
            WVLOG(1) << "poll() call interrupted. retrying... " << peerPort;
            continue;
          }
          WPLOG(ERROR) << "poll() failed " << host << " " << peerPort << " "
                       << fd;
          closeConnection();
          return CONN_ERROR;
        }
        break;
      }
      // have to check whether the connection attempt succeeded
      int connectResult;
      socklen_t len = sizeof(connectResult);
      if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &connectResult, &len) < 0) {
        WPLOG(WARNING) << "getsockopt() failed";
        closeConnection();
        continue;
      }
      if (connectResult != 0) {
        WLOG(WARNING) << "connect did not succeed on " << host << " "
                      << peerPort << " : " << strerrorStr(connectResult);
        closeConnection();
        continue;
      }
    }
    // Set to blocking mode again
    sockArg = fcntl(fd, F_GETFL, nullptr);
    sockArg &= (~O_NONBLOCK);
    res = fcntl(fd, F_SETFL, sockArg);
    if (res == -1) {
      WPLOG(ERROR) << "Could not make the socket blocking " << peerPort;
      closeConnection();
      continue;
    }
    WVLOG(1) << "Successful connect on " << fd;
    peerIp_ = host;
    // Remember the parameters of the address that worked; sa_ is passed as
    // the hints to getaddrinfo() on the next connect().
    sa_ = *info;
    // The copied pointers refer to memory owned by infoList, which the guard
    // frees when this function returns. Clear them (and the address length)
    // so sa_ holds no dangling pointers and remains a valid hints structure.
    sa_.ai_addrlen = 0;
    sa_.ai_addr = nullptr;
    sa_.ai_canonname = nullptr;
    sa_.ai_next = nullptr;
    break;
  }
  if (socket_->getFd() < 0) {
    if (count > 1) {
      // Only log this if not redundant with log above (ie --ipv6=false)
      WLOG(INFO) << "Unable to connect to either of the " << count << " addrs";
    }
    return CONN_ERROR_RETRYABLE;
  }
  socket_->setSocketTimeouts();
  socket_->setDscp(threadCtx_.getOptions().dscp);
  return OK;
}

/// @return   the peer host string recorded by the last successful connect()
const std::string &ClientSocket::getPeerIp() const {
  return peerIp_;
}

/**
 * Applies the send_buffer_size option to the current socket via SO_SNDBUF.
 * Does nothing when the option is <= 0; a setsockopt() failure is logged and
 * otherwise ignored.
 */
void ClientSocket::setSendBufferSize() {
  int bufSize = threadCtx_.getOptions().send_buffer_size;
  auto fd = socket_->getFd();
  auto port = socket_->getPort();
  if (bufSize <= 0) {
    return;
  }
  int status =
      ::setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
  if (status != 0) {
    WPLOG(ERROR) << "Failed to set send buffer " << port << " size " << bufSize
                 << " fd " << fd;
    return;
  }
  WVLOG(1) << "Send buffer size set to " << bufSize << " port " << port;
}

ClientSocket::~ClientSocket() {
}
}
}  // end namespace facebook::wdt