common/protobuf/kudu/util/net/socket.cc (520 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "kudu/util/net/socket.h"
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include <cerrno>
#include <cinttypes>
#include <cstring>
#include <limits>
#include <ostream>
#include <string>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "kudu/gutil/basictypes.h"
#include "kudu/gutil/port.h"
#include "kudu/gutil/stringprintf.h"
#include "kudu/gutil/strings/substitute.h"
#include "kudu/util/debug/trace_event.h"
#include "kudu/util/errno.h"
#include "kudu/util/flag_tags.h"
#include "kudu/util/monotime.h"
#include "kudu/util/net/net_util.h"
#include "kudu/util/net/sockaddr.h"
#include "kudu/util/random.h"
#include "kudu/util/random_util.h"
#include "kudu/util/slice.h"
DEFINE_string(local_ip_for_outbound_sockets, "",
"IP to bind to when making outgoing socket connections. "
"This must be an IP address of the form A.B.C.D, not a hostname. "
"Advanced parameter, subject to change.");
TAG_FLAG(local_ip_for_outbound_sockets, experimental);
DEFINE_bool(socket_inject_short_recvs, false,
"Inject short recv() responses which return less data than "
"requested");
TAG_FLAG(socket_inject_short_recvs, hidden);
TAG_FLAG(socket_inject_short_recvs, unsafe);
using std::string;
using strings::Substitute;
namespace kudu {
Socket::Socket()
: fd_(-1) {
}
Socket::Socket(int fd)
: fd_(fd) {
}
Socket::Socket(Socket&& other) noexcept
: fd_(other.Release()) {
}
void Socket::Reset(int fd) {
ignore_result(Close());
fd_ = fd;
}
int Socket::Release() {
int fd = fd_;
fd_ = -1;
return fd;
}
Socket::~Socket() {
ignore_result(Close());
}
Status Socket::Close() {
if (fd_ < 0) {
return Status::OK();
}
int fd = fd_;
int ret;
RETRY_ON_EINTR(ret, ::close(fd));
if (ret < 0) {
int err = errno;
return Status::NetworkError("close error", ErrnoToString(err), err);
}
fd_ = -1;
return Status::OK();
}
Status Socket::Shutdown(bool shut_read, bool shut_write) {
DCHECK_GE(fd_, 0);
int flags = 0;
if (shut_read && shut_write) {
flags |= SHUT_RDWR;
} else if (shut_read) {
flags |= SHUT_RD;
} else if (shut_write) {
flags |= SHUT_WR;
}
if (::shutdown(fd_, flags) < 0) {
int err = errno;
return Status::NetworkError("shutdown error", ErrnoToString(err), err);
}
return Status::OK();
}
int Socket::GetFd() const {
return fd_;
}
bool Socket::IsTemporarySocketError(int err) {
return ((err == EAGAIN) || (err == EWOULDBLOCK) || (err == EINTR));
}
#if defined(__linux__)
Status Socket::Init(int family, int flags) {
int nonblocking_flag = (flags & FLAG_NONBLOCKING) ? SOCK_NONBLOCK : 0;
Reset(::socket(family, SOCK_STREAM | SOCK_CLOEXEC | nonblocking_flag, 0));
if (fd_ < 0) {
int err = errno;
return Status::NetworkError("error opening socket", ErrnoToString(err), err);
}
return Status::OK();
}
#else
Status Socket::Init(int family, int flags) {
Reset(::socket(family, SOCK_STREAM, 0));
if (fd_ < 0) {
int err = errno;
return Status::NetworkError("error opening socket", ErrnoToString(err), err);
}
RETURN_NOT_OK(SetNonBlocking(flags & FLAG_NONBLOCKING));
RETURN_NOT_OK(SetCloseOnExec());
// Disable SIGPIPE.
int set = 1;
RETURN_NOT_OK_PREPEND(SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, set),
"failed to set SO_NOSIGPIPE");
return Status::OK();
}
#endif // defined(__linux__)
Status Socket::SetNoDelay(bool enabled) {
int flag = enabled ? 1 : 0;
RETURN_NOT_OK_PREPEND(SetSockOpt(IPPROTO_TCP, TCP_NODELAY, flag),
"failed to set TCP_NODELAY");
return Status::OK();
}
Status Socket::SetTcpCork(bool enabled) {
#if defined(__linux__)
int flag = enabled ? 1 : 0;
RETURN_NOT_OK_PREPEND(SetSockOpt(IPPROTO_TCP, TCP_CORK, flag),
"failed to set TCP_CORK");
#endif // defined(__linux__)
// TODO(unknown): Use TCP_NOPUSH for OSX if perf becomes an issue.
return Status::OK();
}
Status Socket::SetNonBlocking(bool enabled) {
int curflags = ::fcntl(fd_, F_GETFL, 0);
if (curflags == -1) {
int err = errno;
return Status::NetworkError(
StringPrintf("Failed to get file status flags on fd %d", fd_),
ErrnoToString(err), err);
}
int newflags = (enabled) ? (curflags | O_NONBLOCK) : (curflags & ~O_NONBLOCK);
if (::fcntl(fd_, F_SETFL, newflags) == -1) {
int err = errno;
if (enabled) {
return Status::NetworkError(
StringPrintf("Failed to set O_NONBLOCK on fd %d", fd_),
ErrnoToString(err), err);
} else {
return Status::NetworkError(
StringPrintf("Failed to clear O_NONBLOCK on fd %d", fd_),
ErrnoToString(err), err);
}
}
return Status::OK();
}
Status Socket::IsNonBlocking(bool* is_nonblock) const {
int curflags = ::fcntl(fd_, F_GETFL, 0);
if (curflags == -1) {
int err = errno;
return Status::NetworkError(
StringPrintf("Failed to get file status flags on fd %d", fd_),
ErrnoToString(err), err);
}
*is_nonblock = ((curflags & O_NONBLOCK) != 0);
return Status::OK();
}
Status Socket::SetCloseOnExec() {
int curflags = fcntl(fd_, F_GETFD, 0);
if (curflags == -1) {
int err = errno;
Reset(-1);
return Status::NetworkError("fcntl(F_GETFD) error", ErrnoToString(err), err);
}
if (fcntl(fd_, F_SETFD, curflags | FD_CLOEXEC) == -1) {
int err = errno;
Reset(-1);
return Status::NetworkError("fcntl(F_SETFD) error", ErrnoToString(err), err);
}
return Status::OK();
}
Status Socket::SetSendTimeout(const MonoDelta& timeout) {
return SetTimeout(SO_SNDTIMEO, "SO_SNDTIMEO", timeout);
}
Status Socket::SetRecvTimeout(const MonoDelta& timeout) {
return SetTimeout(SO_RCVTIMEO, "SO_RCVTIMEO", timeout);
}
Status Socket::SetReuseAddr(bool flag) {
int int_flag = flag ? 1 : 0;
RETURN_NOT_OK_PREPEND(SetSockOpt(SOL_SOCKET, SO_REUSEADDR, int_flag),
"failed to set SO_REUSEADDR");
return Status::OK();
}
Status Socket::SetReusePort(bool flag) {
#ifdef SO_REUSEPORT
int int_flag = flag ? 1 : 0;
RETURN_NOT_OK_PREPEND(SetSockOpt(SOL_SOCKET, SO_REUSEPORT, int_flag),
"failed to set SO_REUSEPORT");
return Status::OK();
#else
return Status::NotSupported("failed to set SO_REUSEPORT: protocol not available");
#endif
}
Status Socket::BindAndListen(const Sockaddr &sockaddr,
int listen_queue_size) {
RETURN_NOT_OK(SetReuseAddr(true));
RETURN_NOT_OK(Bind(sockaddr));
RETURN_NOT_OK(Listen(listen_queue_size));
return Status::OK();
}
Status Socket::Listen(int listen_queue_size) {
if (listen(fd_, listen_queue_size)) {
int err = errno;
return Status::NetworkError("listen() error", ErrnoToString(err));
}
return Status::OK();
}
Status Socket::GetSocketAddress(Sockaddr *cur_addr) const {
struct sockaddr_storage ss;
socklen_t len = sizeof(ss);
DCHECK_GE(fd_, 0);
if (::getsockname(fd_, reinterpret_cast<struct sockaddr*>(&ss), &len) == -1) {
int err = errno;
return Status::NetworkError("getsockname error", ErrnoToString(err), err);
}
*cur_addr = Sockaddr(reinterpret_cast<struct sockaddr&>(ss), len);
return Status::OK();
}
Status Socket::GetPeerAddress(Sockaddr *cur_addr) const {
struct sockaddr_storage addr;
socklen_t len = sizeof(addr);
DCHECK_GE(fd_, 0);
if (::getpeername(fd_, reinterpret_cast<struct sockaddr*>(&addr), &len) == -1) {
int err = errno;
return Status::NetworkError("getpeername error", ErrnoToString(err), err);
}
*cur_addr = Sockaddr(reinterpret_cast<const sockaddr&>(addr), len);
return Status::OK();
}
bool Socket::IsLoopbackConnection() const {
Sockaddr local, remote;
if (!GetSocketAddress(&local).ok()) return false;
if (!GetPeerAddress(&remote).ok()) return false;
// Check if remote address is in 127.0.0.0/8 subnet.
if (remote.IsAnyLocalAddress()) {
return true;
}
// Compare local and remote addresses without comparing ports.
local.set_port(0);
remote.set_port(0);
return local == remote;
}
Status Socket::Bind(const Sockaddr& bind_addr) {
DCHECK_GE(fd_, 0);
if (PREDICT_FALSE(::bind(fd_, bind_addr.addr(), bind_addr.addrlen()))) {
int err = errno;
Status s = Status::NetworkError(
strings::Substitute("error binding socket to $0: $1",
bind_addr.ToString(), ErrnoToString(err)),
Slice(), err);
if (s.IsNetworkError() && bind_addr.is_ip() &&
s.posix_code() == EADDRINUSE && bind_addr.port() != 0) {
TryRunLsof(bind_addr);
}
return s;
}
return Status::OK();
}
Status Socket::Accept(Socket *new_conn, Sockaddr *remote, int flags) {
TRACE_EVENT0("net", "Socket::Accept");
struct sockaddr_storage addr;
socklen_t olen = sizeof(addr);
DCHECK_GE(fd_, 0);
#if defined(__linux__)
int accept_flags = SOCK_CLOEXEC;
if (flags & FLAG_NONBLOCKING) {
accept_flags |= SOCK_NONBLOCK;
}
int fd = -1;
RETRY_ON_EINTR(fd, accept4(fd_, (struct sockaddr*)&addr,
&olen, accept_flags));
if (fd < 0) {
int err = errno;
return Status::NetworkError("accept4(2) error", ErrnoToString(err), err);
}
new_conn->Reset(fd);
#else
int fd = -1;
RETRY_ON_EINTR(fd, accept(fd_, (struct sockaddr*)&addr, &olen));
if (fd < 0) {
int err = errno;
return Status::NetworkError("accept(2) error", ErrnoToString(err), err);
}
new_conn->Reset(fd);
RETURN_NOT_OK(new_conn->SetNonBlocking(flags & FLAG_NONBLOCKING));
RETURN_NOT_OK(new_conn->SetCloseOnExec());
#endif // defined(__linux__)
*remote = Sockaddr(reinterpret_cast<const sockaddr&>(addr), olen);
TRACE_EVENT_INSTANT1("net", "Accepted", TRACE_EVENT_SCOPE_THREAD,
"remote", remote->ToString());
return Status::OK();
}
Status Socket::BindForOutgoingConnection() {
Sockaddr bind_host;
Status s = bind_host.ParseString(FLAGS_local_ip_for_outbound_sockets, 0);
CHECK(s.ok() && bind_host.port() == 0)
<< "Invalid local IP set for 'local_ip_for_outbound_sockets': '"
<< FLAGS_local_ip_for_outbound_sockets << "': " << s.ToString();
RETURN_NOT_OK(Bind(bind_host));
return Status::OK();
}
Status Socket::Connect(const Sockaddr &remote) {
TRACE_EVENT1("net", "Socket::Connect",
"remote", remote.ToString());
if (PREDICT_FALSE(!FLAGS_local_ip_for_outbound_sockets.empty())) {
RETURN_NOT_OK(BindForOutgoingConnection());
}
DCHECK_GE(fd_, 0);
int ret;
RETRY_ON_EINTR(ret, ::connect(fd_, remote.addr(), remote.addrlen()));
if (ret < 0) {
int err = errno;
return Status::NetworkError("connect(2) error", ErrnoToString(err), err);
}
return Status::OK();
}
Status Socket::GetSockError() const {
int val = 0, ret;
socklen_t val_len = sizeof(val);
DCHECK_GE(fd_, 0);
ret = ::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &val, &val_len);
if (ret) {
int err = errno;
return Status::NetworkError("getsockopt(SO_ERROR) failed", ErrnoToString(err), err);
}
if (val != 0) {
return Status::NetworkError(ErrnoToString(val), Slice(), val);
}
return Status::OK();
}
Status Socket::Write(const uint8_t *buf, int32_t amt, int32_t *nwritten) {
if (amt <= 0) {
return Status::NetworkError(
StringPrintf("invalid send of %" PRId32 " bytes",
amt), Slice(), EINVAL);
}
DCHECK_GE(fd_, 0);
int res;
RETRY_ON_EINTR(res, ::send(fd_, buf, amt, MSG_NOSIGNAL));
if (res < 0) {
int err = errno;
return Status::NetworkError("write error", ErrnoToString(err), err);
}
*nwritten = res;
return Status::OK();
}
Status Socket::Writev(const struct ::iovec *iov, int iov_len,
int64_t *nwritten) {
if (PREDICT_FALSE(iov_len <= 0)) {
return Status::NetworkError(
StringPrintf("writev: invalid io vector length of %d",
iov_len),
Slice(), EINVAL);
}
DCHECK_GE(fd_, 0);
struct msghdr msg;
memset(&msg, 0, sizeof(struct msghdr));
msg.msg_iov = const_cast<iovec *>(iov);
msg.msg_iovlen = iov_len;
ssize_t res;
RETRY_ON_EINTR(res, ::sendmsg(fd_, &msg, MSG_NOSIGNAL));
if (PREDICT_FALSE(res < 0)) {
int err = errno;
return Status::NetworkError("sendmsg error", ErrnoToString(err), err);
}
*nwritten = res;
return Status::OK();
}
// Mostly follows writen() from Stevens (2004) or Kerrisk (2010).
Status Socket::BlockingWrite(const uint8_t *buf, size_t buflen, size_t *nwritten,
const MonoTime& deadline) {
DCHECK_LE(buflen, std::numeric_limits<int32_t>::max()) << "Writes > INT32_MAX not supported";
DCHECK(nwritten);
size_t tot_written = 0;
while (tot_written < buflen) {
int32_t inc_num_written = 0;
int32_t num_to_write = buflen - tot_written;
MonoDelta timeout = deadline - MonoTime::Now();
if (PREDICT_FALSE(timeout.ToNanoseconds() <= 0)) {
return Status::TimedOut(Substitute("sent $0 of $1 requested bytes",
tot_written, buflen));
}
RETURN_NOT_OK(SetSendTimeout(timeout));
Status s = Write(buf, num_to_write, &inc_num_written);
tot_written += inc_num_written;
buf += inc_num_written;
*nwritten = tot_written;
if (PREDICT_FALSE(!s.ok())) {
// Continue silently when the syscall is interrupted.
if (s.posix_code() == EINTR) {
continue;
}
if (s.posix_code() == EAGAIN) {
return Status::TimedOut(Substitute("sent $0 of $1 requested bytes",
tot_written, buflen));
}
return s.CloneAndPrepend("BlockingWrite error");
}
if (PREDICT_FALSE(inc_num_written == 0)) {
// Shouldn't happen on Linux with a blocking socket. Maybe other Unices.
break;
}
}
if (tot_written < buflen) {
return Status::IOError("Wrote zero bytes on a BlockingWrite() call",
StringPrintf("Transferred %zu of %zu bytes", tot_written, buflen));
}
return Status::OK();
}
Status Socket::Recv(uint8_t *buf, int32_t amt, int32_t *nread) {
if (amt <= 0) {
return Status::NetworkError(
StringPrintf("invalid recv of %d bytes", amt), Slice(), EINVAL);
}
// The recv() call can return fewer than the requested number of bytes.
// Especially when 'amt' is small, this is very unlikely to happen in
// the context of unit tests. So, we provide an injection hook which
// simulates the same behavior.
if (PREDICT_FALSE(FLAGS_socket_inject_short_recvs && amt > 1)) {
Random r(GetRandomSeed32());
amt = 1 + r.Uniform(amt - 1);
}
DCHECK_GE(fd_, 0);
int res;
RETRY_ON_EINTR(res, recv(fd_, buf, amt, 0));
if (res <= 0) {
Sockaddr remote;
Status get_addr_status = GetPeerAddress(&remote);
string remote_str = get_addr_status.ok() ? remote.ToString() : "unknown peer";
if (res == 0) {
string error_message = Substitute("recv got EOF from $0", remote_str);
return Status::NetworkError(error_message, Slice(), ESHUTDOWN);
}
int err = errno;
string error_message = Substitute("recv error from $0", remote_str);
return Status::NetworkError(error_message, ErrnoToString(err), err);
}
*nread = res;
return Status::OK();
}
// Mostly follows readn() from Stevens (2004) or Kerrisk (2010).
// One place where we deviate: we consider EOF a failure if < amt bytes are read.
Status Socket::BlockingRecv(uint8_t *buf, size_t amt, size_t *nread, const MonoTime& deadline) {
DCHECK_LE(amt, std::numeric_limits<int32_t>::max()) << "Reads > INT32_MAX not supported";
DCHECK(nread);
size_t tot_read = 0;
while (tot_read < amt) {
int32_t inc_num_read = 0;
int32_t num_to_read = amt - tot_read;
MonoDelta timeout = deadline - MonoTime::Now();
if (PREDICT_FALSE(timeout.ToNanoseconds() <= 0)) {
return Status::TimedOut(Substitute("received $0 of $1 requested bytes",
tot_read, amt));
}
RETURN_NOT_OK(SetRecvTimeout(timeout));
Status s = Recv(buf, num_to_read, &inc_num_read);
tot_read += inc_num_read;
buf += inc_num_read;
*nread = tot_read;
if (PREDICT_FALSE(!s.ok())) {
// Continue silently when the syscall is interrupted.
if (s.posix_code() == EINTR) {
continue;
}
if (s.posix_code() == EAGAIN) {
return Status::TimedOut(Substitute("received $0 of $1 requested bytes",
tot_read, amt));
}
return s.CloneAndPrepend("BlockingRecv error");
}
if (PREDICT_FALSE(inc_num_read == 0)) {
// EOF.
break;
}
}
if (PREDICT_FALSE(tot_read < amt)) {
return Status::IOError("Read zero bytes on a blocking Recv() call",
StringPrintf("Transferred %zu of %zu bytes", tot_read, amt));
}
return Status::OK();
}
Status Socket::SetTimeout(int opt, const char* optname, const MonoDelta& timeout) {
if (PREDICT_FALSE(timeout.ToNanoseconds() < 0)) {
return Status::InvalidArgument("Timeout specified as negative to SetTimeout",
timeout.ToString());
}
struct timeval tv;
timeout.ToTimeVal(&tv);
RETURN_NOT_OK_PREPEND(SetSockOpt(SOL_SOCKET, opt, tv),
Substitute("failed to set socket option $0 to $1",
optname, timeout.ToString()));
return Status::OK();
}
Status Socket::SetTcpKeepAlive(int idle_time_s, int retry_time_s, int num_retries) {
#if defined(__linux__)
static const char* const err_string = "failed to set socket option $0 to $1";
DCHECK_GT(idle_time_s, 0);
RETURN_NOT_OK_PREPEND(SetSockOpt(IPPROTO_TCP, TCP_KEEPIDLE, idle_time_s),
Substitute(err_string, "TCP_KEEPIDLE", idle_time_s));
DCHECK_GT(retry_time_s, 0);
RETURN_NOT_OK_PREPEND(SetSockOpt(IPPROTO_TCP, TCP_KEEPINTVL, retry_time_s),
Substitute(err_string, "TCP_KEEPINTVL", retry_time_s));
DCHECK_GT(num_retries, 0);
RETURN_NOT_OK_PREPEND(SetSockOpt(IPPROTO_TCP, TCP_KEEPCNT, num_retries),
Substitute(err_string, "TCP_KEEPCNT", num_retries));
RETURN_NOT_OK_PREPEND(SetSockOpt(SOL_SOCKET, SO_KEEPALIVE, 1),
"failed to enable TCP KeepAlive socket option");
#endif
return Status::OK();
}
template<typename T>
Status Socket::SetSockOpt(int level, int option, const T& value) {
if (::setsockopt(fd_, level, option, &value, sizeof(T)) == -1) {
int err = errno;
return Status::NetworkError(ErrnoToString(err), Slice(), err);
}
return Status::OK();
}
} // namespace kudu