net/rsocket/rsocket.cpp (481 lines of code) (raw):

#include "rsocket.h"

#include <fcntl.h>
#include <photon/common/alog-stdstring.h>
#include <photon/common/alog.h>
#include <photon/common/iovector.h>
#include <photon/common/timeout.h>
#include <photon/net/basic_socket.h>
#include <photon/net/socket.h>
#include <photon/thread/thread11.h>
#include <rdma/rsocket.h>
#include <vector>

#include <poll.h>
#include <sys/socket.h>

#include <atomic>
#include <cerrno>
#include <cstring>
#include <utility>

namespace photon {
namespace net {

// Adapts an iovec array to the ::msghdr expected by rrecvmsg/rsendmsg.
// The header only borrows `iov`; the array must outlive every use of it.
struct tmp_msg_hdr : public ::msghdr {
    tmp_msg_hdr(const iovec* iov, int iovcnt) : ::msghdr{} {
        this->msg_iov = const_cast<iovec*>(iov);
        this->msg_iovlen = iovcnt;
    }
    explicit tmp_msg_hdr(iovector_view& view) : tmp_msg_hdr(view.iov, view.iovcnt) {}
    operator ::msghdr*() { return this; }
};

// Drops the first `bytes` bytes from the iovec array (`iov`, `iovcnt`):
// fully-consumed entries are skipped and a partially-consumed front entry is
// shrunk in place, so `iov` must point into writable storage.
static void iov_consume(iovec*& iov, int& iovcnt, size_t bytes) {
    while (iovcnt > 0 && bytes >= iov->iov_len) {
        bytes -= iov->iov_len;
        ++iov;
        --iovcnt;
    }
    if (iovcnt > 0 && bytes > 0) {
        iov->iov_base = static_cast<char*>(iov->iov_base) + bytes;
        iov->iov_len -= bytes;
    }
}

// Owning wrapper around an rsocket descriptor. The descriptor is switched to
// non-blocking mode on construction and closed with rclose() on destruction.
// All blocking-style operations are emulated by polling with rpoll() and
// yielding the current photon thread.
struct RSockFD {
    int fd = -1;

    // Takes ownership of an existing rsocket descriptor (e.g. from raccept).
    explicit RSockFD(int fd) : fd(fd) {
        if (fd >= 0) rfcntl(fd, F_SETFL, O_NONBLOCK);
    }
    // Creates a new IPv4 stream rsocket; check operator bool for success.
    RSockFD() : fd(rsocket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) {
        if (fd >= 0) rfcntl(fd, F_SETFL, O_NONBLOCK);
    }
    ~RSockFD() { close(); }

    RSockFD(const RSockFD&) = delete;
    RSockFD& operator=(const RSockFD&) = delete;
    // Moves swap descriptors; a moved-from object keeps -1 (or the
    // destination's previous fd, which it then closes when destroyed).
    RSockFD(RSockFD&& o) noexcept { std::swap(fd, o.fd); }
    RSockFD& operator=(RSockFD&& o) noexcept {
        std::swap(fd, o.fd);
        return *this;
    }

    operator bool() const { return fd >= 0; }

    // Closes the descriptor once; returns rclose()'s result, or -1 if there
    // was nothing to close.
    int close() {
        int ret = -1;
        if (fd >= 0) {
            ret = rclose(fd);
            fd = -1;
        }
        return ret;
    }

    Object* get_underlay_object(uint64_t) { return (Object*)(uint64_t)fd; }

    int setsockopt(int level, int option_name, const void* option_value,
                   socklen_t option_len) {
        return rsetsockopt(fd, level, option_name, option_value, option_len);
    }

    int getsockopt(int level, int option_name, void* option_value,
                   socklen_t* option_len) {
        return rgetsockopt(fd, level, option_name, option_value, option_len);
    }

    // NOTE(review): not called anywhere in this file; kept unchanged.
    int do_get_name(int fd, Getter getter, EndPoint& addr) {
        sockaddr_storage storage;
        socklen_t len = storage.get_max_socklen();
        int ret = getter(fd, storage.get_sockaddr(), &len);
        if (ret < 0 || len > storage.get_max_socklen()) return -1;
        addr = storage.to_endpoint();
        return 0;
    }

    int getsockname(photon::net::EndPoint& addr) {
        sockaddr_storage storage;
        socklen_t addrlen = sizeof(storage.store);
        auto ret = rgetsockname(fd, (sockaddr*)&storage.store, &addrlen);
        if (ret == 0) addr = storage.to_endpoint();
        return ret;
    }

    int getpeername(photon::net::EndPoint& addr) {
        sockaddr_storage storage;
        socklen_t addrlen = sizeof(storage.store);
        auto ret = rgetpeername(fd, (sockaddr*)&storage.store, &addrlen);
        if (ret == 0) addr = storage.to_endpoint();
        return ret;
    }

    // Polls the descriptor with a zero-timeout rpoll(), yielding the photon
    // thread between attempts, until something is reported or `timeout`
    // (a photon Timeout duration) expires.
    // Returns the reported revents restricted to ERR/HUP/IN/OUT, or -1 with
    // errno set: ETIMEDOUT on timeout, EBADF on POLLNVAL, or rpoll's errno.
    int do_poll(int events, uint64_t timeout) {
        struct pollfd fds {};
        fds.fd = fd;
        fds.events = events;
        Timeout tmo(timeout);
        int ret;
        while ((ret = rpoll(&fds, 1, 0)) == 0 && tmo.timeout()) {
            photon::thread_yield();
        }
        if (ret < 0) return ret;
        if (ret == 0) {
            errno = ETIMEDOUT;
            return -1;
        }
        if (fds.revents & POLLNVAL) {
            errno = EBADF;
            return -1;
        }
        return fds.revents & (POLLERR | POLLHUP | POLLIN | POLLOUT);
    }

    // Performs one non-blocking action `f(fd, args...)` once the socket is
    // ready for EVENT, retrying on EAGAIN/EWOULDBLOCK until `t` expires.
    // Returns f's result, or -1 with errno set (ETIMEDOUT on timeout).
    template <int EVENT, typename FUNC, typename... ARGS>
    ssize_t do_io_action(uint64_t t, FUNC f, ARGS... args) {
        Timeout tmo(t);
        do {
            auto ret = do_poll(EVENT | POLLERR | POLLHUP, tmo.timeout());
            if (ret < 0) return ret;
            // On POLLERR/POLLHUP the action is still attempted so that it
            // surfaces the pending error (-1/errno) or end-of-stream (0)
            // instead of spinning until the timeout.
            if (ret & (EVENT | POLLERR | POLLHUP)) {
                auto x = f(fd, args...);
                if (x < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) continue;
                return x;
            }
        } while (tmo.timeout());
        errno = ETIMEDOUT;
        return -1;
    }

    // Repeats `step(remaining_timeout, bytes_done)` until `n` bytes have been
    // transferred, the peer reports end-of-stream (step returns 0), or an
    // error/timeout occurs (step returns < 0). `step` must act on the
    // not-yet-transferred tail. Returns bytes transferred, or the negative
    // step result.
    template <typename STEP>
    ssize_t do_io_fully(uint64_t t, size_t n, STEP&& step) {
        Timeout tmo(t);
        size_t done = 0;
        while (done < n) {
            ssize_t ret = step(tmo.timeout(), done);
            if (ret < 0) return ret;
            if (ret == 0) break;
            done += static_cast<size_t>(ret);
        }
        return static_cast<ssize_t>(done);
    }

    // Transfers the whole iovec array with rrecvmsg/rsendmsg-style `f`,
    // advancing past partially transferred data between calls.
    template <int EVENT, typename MSGFUNC>
    ssize_t do_iov_fully_action(uint64_t t, MSGFUNC f, const struct iovec* iov,
                                int iovcnt) {
        if (iovcnt < 0) {
            errno = EINVAL;
            return -1;
        }
        // Private copy: partially-consumed entries are rewritten as the
        // transfer advances, and the caller's array is const.
        std::vector<iovec> rest(iov, iov + iovcnt);
        size_t total = 0;
        for (const auto& v : rest) total += v.iov_len;
        iovec* cur = rest.data();
        int cnt = iovcnt;
        return do_io_fully(t, total, [&](uint64_t tm, size_t) -> ssize_t {
            tmp_msg_hdr msg(cur, cnt);
            ssize_t ret =
                do_io_action<EVENT>(tm, f, static_cast<::msghdr*>(&msg), 0);
            if (ret > 0) iov_consume(cur, cnt, static_cast<size_t>(ret));
            return ret;
        });
    }

    // Single-shot receive; may return fewer than `count` bytes.
    ssize_t recv(uint64_t timeout, void* buf, size_t count, int flags = 0) {
        return do_io_action<POLLIN>(timeout, rrecv, buf, count, flags);
    }

    ssize_t recv(uint64_t timeout, const struct iovec* iov, int iovcnt,
                 int flags = 0) {
        tmp_msg_hdr msg(iov, iovcnt);
        return do_io_action<POLLIN>(timeout, rrecvmsg, msg, flags);
    }

    // Reads exactly `count` bytes unless EOF, error, or timeout intervenes.
    ssize_t read(uint64_t timeout, void* buf, size_t count) {
        auto base = static_cast<char*>(buf);
        return do_io_fully(timeout, count, [&](uint64_t t, size_t done) {
            return do_io_action<POLLIN>(t, rrecv, base + done, count - done, 0);
        });
    }

    ssize_t readv(uint64_t timeout, const struct iovec* iov, int iovcnt) {
        return do_iov_fully_action<POLLIN>(timeout, rrecvmsg, iov, iovcnt);
    }

    // Single-shot send; may transfer fewer than `count` bytes.
    ssize_t send(uint64_t timeout, const void* buf, size_t count, int flags = 0) {
        return do_io_action<POLLOUT>(timeout, rsend, buf, count, flags);
    }

    ssize_t send(uint64_t timeout, const struct iovec* iov, int iovcnt,
                 int flags = 0) {
        tmp_msg_hdr msg(iov, iovcnt);
        return do_io_action<POLLOUT>(timeout, rsendmsg, msg, flags);
    }

    // Writes exactly `count` bytes unless error or timeout intervenes.
    ssize_t write(uint64_t timeout, const void* buf, size_t count) {
        auto base = static_cast<const char*>(buf);
        return do_io_fully(timeout, count, [&](uint64_t t, size_t done) {
            return do_io_action<POLLOUT>(t, rsend, base + done, count - done, 0);
        });
    }

    ssize_t writev(uint64_t timeout, const struct iovec* iov, int iovcnt) {
        return do_iov_fully_action<POLLOUT>(timeout, rsendmsg, iov, iovcnt);
    }

    int shutdown(ShutdownHow how) { return rshutdown(fd, static_cast<int>(how)); }

    int bind(const photon::net::EndPoint& addr) {
        sockaddr_storage addr_storage(addr);
        return rbind(fd, addr_storage.get_sockaddr(), addr_storage.get_socklen());
    }

    int listen(int backlog) { return rlisten(fd, backlog); }

    // Accepts one connection, yielding while none is pending. Fails with
    // EINTR when `interrupt()` returns true, ETIMEDOUT when `timeout`
    // expires, or raccept's errno on any other error.
    template <typename Func>
    RSockFD accept(uint64_t timeout, photon::net::EndPoint* remote_endpoint,
                   Func&& interrupt) {
        sockaddr_storage addr;
        Timeout tmo(timeout);
        int ret;
        do {
            socklen_t socklen = sizeof(addr.store);
            ret = raccept(fd, (sockaddr*)&addr.store, &socklen);
            if (ret >= 0) break;
            if (errno != EWOULDBLOCK && errno != EAGAIN) return RSockFD(-1);
            if (interrupt()) {
                errno = EINTR;
                return RSockFD(-1);
            }
            photon::thread_yield();
        } while (tmo.timeout());
        if (ret < 0) {
            errno = ETIMEDOUT;
            return RSockFD(-1);
        }
        if (remote_endpoint) *remote_endpoint = addr.to_endpoint();
        return RSockFD(ret);
    }

    // Connects to `remote_endpoint`, waiting up to `timeout` for a
    // non-blocking connect to complete. Returns 0 on success, -1 with errno.
    int connect(uint64_t timeout, const photon::net::EndPoint& remote_endpoint) {
        Timeout tmo(timeout);
        sockaddr_storage addr(remote_endpoint);
        auto ret = rconnect(fd, (sockaddr*)&addr.store, addr.get_socklen());
        if (ret >= 0 || errno != EINPROGRESS) return ret;
        auto r = do_poll(POLLOUT, tmo.timeout());
        if (r < 0) return r;
        if ((r & POLLOUT) && !(r & (POLLERR | POLLHUP))) return 0;
        // The connect failed asynchronously; fetch the cause.
        // NOTE(review): assumes rgetsockopt supports SOL_SOCKET/SO_ERROR (as
        // librdmacm's rsocket implementation does) — ECONNREFUSED otherwise.
        int err = 0;
        socklen_t len = sizeof(err);
        if (getsockopt(SOL_SOCKET, SO_ERROR, &err, &len) < 0 || err == 0)
            err = ECONNREFUSED;
        errno = err;
        return -1;
    }
};

// ISocketStream backed by a connected rsocket; every I/O call uses the
// stream's current timeout (default -1UL).
class RSocketStream : public photon::net::ISocketStream {
public:
    RSockFD rfd;
    uint64_t m_timeout = -1UL;

    explicit RSocketStream(RSockFD&& fd) : rfd(std::move(fd)) {}

    uint64_t timeout() const override { return m_timeout; }
    void timeout(uint64_t tm) override { m_timeout = tm; }

    Object* get_underlay_object(uint64_t) override {
        return rfd.get_underlay_object(0);
    }

    int setsockopt(int level, int option_name, const void* option_value,
                   socklen_t option_len) override {
        return rfd.setsockopt(level, option_name, option_value, option_len);
    }

    int getsockopt(int level, int option_name, void* option_value,
                   socklen_t* option_len) override {
        return rfd.getsockopt(level, option_name, option_value, option_len);
    }

    int getsockname(photon::net::EndPoint& addr) override {
        return rfd.getsockname(addr);
    }
    int getpeername(photon::net::EndPoint& addr) override {
        return rfd.getpeername(addr);
    }
    int getsockname(char* path, size_t count) override {
        errno = ENOSYS;
        return -1;
    }
    int getpeername(char* path, size_t count) override {
        errno = ENOSYS;
        return -1;
    }

    ssize_t sendfile(int in_fd, off_t offset, size_t count) override {
        errno = ENOSYS;
        return -1;
    }

    int close() override { return rfd.close(); }
    int shutdown(ShutdownHow how) override { return rfd.shutdown(how); }

    ssize_t recv(void* buf, size_t count, int flags = 0) override {
        return rfd.recv(m_timeout, buf, count, flags);
    }
    ssize_t recv(const struct iovec* iov, int iovcnt, int flags = 0) override {
        return rfd.recv(m_timeout, iov, iovcnt, flags);
    }
    ssize_t send(const void* buf, size_t count, int flags = 0) override {
        return rfd.send(m_timeout, buf, count, flags);
    }
    ssize_t send(const struct iovec* iov, int iovcnt, int flags = 0) override {
        return rfd.send(m_timeout, iov, iovcnt, flags);
    }
    ssize_t read(void* buf, size_t count) override {
        return rfd.read(m_timeout, buf, count);
    }
    ssize_t readv(const struct iovec* iov, int iovcnt) override {
        return rfd.readv(m_timeout, iov, iovcnt);
    }
    ssize_t write(const void* buf, size_t count) override {
        return rfd.write(m_timeout, buf, count);
    }
    ssize_t writev(const struct iovec* iov, int iovcnt) override {
        return rfd.writev(m_timeout, iov, iovcnt);
    }
};

// One recorded socket option; `opt_val` points into a SockOptBuffer.
struct SocketOpt {
    int level;
    int opt_name;
    void* opt_val;
    socklen_t opt_len;
};

// Records socket options (values copied into a fixed 4 KiB arena) so they can
// be replayed onto sockets created later.
class SockOptBuffer : public std::vector<SocketOpt> {
protected:
    static constexpr uint64_t BUFFERSIZE = 4096;
    char buffer[BUFFERSIZE];
    char* ptr = buffer;

public:
    SockOptBuffer() = default;
    // Entries point into this object's own `buffer`; a copy would alias the
    // source's storage, so copying is disabled.
    SockOptBuffer(const SockOptBuffer&) = delete;
    SockOptBuffer& operator=(const SockOptBuffer&) = delete;

    // Stores a copy of the option value. Returns 0, or -1 with ENOBUFS when
    // the arena cannot hold `len` more bytes.
    int put_opt(int level, int name, const void* val, socklen_t len) {
        if (len > static_cast<size_t>(buffer + BUFFERSIZE - ptr)) {
            errno = ENOBUFS;
            return -1;
        }
        if (len) memcpy(ptr, val, len);
        push_back(SocketOpt{level, name, ptr, len});
        ptr += len;
        return 0;
    }

    // Copies the most recently stored value of (level, name) into `val` and
    // sets `*len` to its size. Returns -1 with ENOPROTOOPT if never set, or
    // EINVAL if `*len` is too small.
    int get_opt(int level, int name, void* val, socklen_t* len) {
        // Newest first: when an option is set repeatedly, the last value is
        // the one that setsockopt() below applies last.
        for (auto it = rbegin(); it != rend(); ++it) {
            if (it->level != level || it->opt_name != name) continue;
            if (*len < it->opt_len) {
                errno = EINVAL;
                return -1;
            }
            memcpy(val, it->opt_val, *len = it->opt_len);
            return 0;
        }
        errno = ENOPROTOOPT;
        return -1;
    }

    // Applies every recorded option, in insertion order, to `rfd`.
    virtual int setsockopt(RSockFD& rfd) {
        for (auto& opt : *this) {
            if (rfd.setsockopt(opt.level, opt.opt_name, opt.opt_val, opt.opt_len) != 0) {
                // Keep the errno reported by rsetsockopt.
                LOG_ERRNO_RETURN(0, -1, "Failed to setsockopt ", VALUE(opt.level),
                                 VALUE(opt.opt_name), VALUE(opt.opt_val));
            }
        }
        return 0;
    }
};

// ISocketClient that opens a fresh rsocket per connect(); options set on the
// client are recorded and replayed onto each new socket.
class RSocketClient : public photon::net::ISocketClient {
public:
    uint64_t m_timeout = -1UL;
    SockOptBuffer m_opts;

    photon::net::ISocketStream* connect(const char* path, size_t count = 0) override {
        errno = ENOSYS;
        return nullptr;
    }

    // Creates a socket, applies recorded options, optionally binds `local`,
    // then connects within the client timeout. The returned stream inherits
    // the client timeout. Returns nullptr with errno set on failure.
    photon::net::ISocketStream* connect(const photon::net::EndPoint& remote,
                                        const photon::net::EndPoint* local) override {
        RSockFD sock;
        if (!sock) LOG_ERRNO_RETURN(0, nullptr, "failed to create rsocket");
        if (m_opts.setsockopt(sock) < 0) return nullptr;
        if (local != nullptr && sock.bind(*local) < 0)
            LOG_ERRNO_RETURN(0, nullptr, "failed to bind local port");
        if (sock.connect(m_timeout, remote) < 0)
            LOG_ERRNO_RETURN(0, nullptr, "failed to connect to remote");
        auto stream = new RSocketStream(std::move(sock));
        stream->timeout(m_timeout);
        return stream;
    }

    ~RSocketClient() override {}

    Object* get_underlay_object(uint64_t) override {
        errno = ENOSYS;
        return nullptr;
    }

    int setsockopt(int level, int option_name, const void* option_value,
                   socklen_t option_len) override {
        return m_opts.put_opt(level, option_name, option_value, option_len);
    }

    int getsockopt(int level, int option_name, void* option_value,
                   socklen_t* option_len) override {
        return m_opts.get_opt(level, option_name, option_value, option_len);
    }

    uint64_t timeout() const override { return m_timeout; }
    void timeout(uint64_t tm) override { m_timeout = tm; }
};

// ISocketServer over a listening rsocket. The accept loop runs either in the
// caller (block mode) or in a joinable photon thread; each accepted stream is
// handled in its own photon thread and deleted after the handler returns.
class RSocketServer : public photon::net::ISocketServer {
public:
    RSockFD rfd;
    bool m_block_serv = false;
    bool waiting = false;  // true while the loop is parked inside accept()
    Handler m_handler;
    std::atomic<photon::thread*> workth{};  // loop thread; nullptr when stopped
    uint64_t m_timeout = -1UL;

    ~RSocketServer() override { terminate(); }

    int bind(const EndPoint& ep) override { return rfd.bind(ep); }
    int bind(const char* path, size_t count) override {
        errno = ENOSYS;
        return -1;
    }
    int listen(int backlog = 1024) override { return rfd.listen(backlog); }

    photon::net::ISocketStream* accept(
        photon::net::EndPoint* remote_endpoint = nullptr) override {
        auto fd = rfd.accept(m_timeout, remote_endpoint,
                             [&] { return !check_running(); });
        if (!fd) return nullptr;
        return new RSocketStream(std::move(fd));
    }

    ISocketServer* set_handler(Handler handler) override {
        m_handler = handler;
        return this;
    }

    // Per-connection thread body: runs the handler, then frees the stream.
    static void handler(Handler m_handler, photon::net::ISocketStream* sess) {
        m_handler(sess);
        delete sess;
    }

    bool check_running() { return workth.load(); }

    // Claims the server for the current thread and accepts until terminate().
    int accept_loop() {
        photon::thread* th = nullptr;
        if (!workth.compare_exchange_strong(th, photon::CURRENT))
            LOG_ERROR_RETURN(EALREADY, -1, "Already listening");
        while (check_running()) {
            waiting = true;
            auto connection = accept();
            waiting = false;
            if (!check_running()) return 0;
            if (connection) {
                connection->timeout(m_timeout);
                photon::thread_create11(&RSocketServer::handler, m_handler, connection);
            } else {
                LOG_WARN("RSocketServer: failed to accept new connections: `", ERRNO());
                photon::thread_yield();
            }
        }
        return 0;
    }

    // photon thread entry trampoline for the non-blocking accept loop.
    static void* accept_loop_entry(void* arg) {
        static_cast<RSocketServer*>(arg)->accept_loop();
        return nullptr;
    }

    int start_loop(bool block) override {
        if (check_running()) LOG_ERROR_RETURN(EALREADY, -1, "Already listening");
        m_block_serv = block;
        if (block) return accept_loop();
        auto th = photon::thread_create(&RSocketServer::accept_loop_entry, this);
        thread_enable_join(th);
        thread_yield_to(th);
        return 0;
    }

    // Stops the accept loop: wakes it if parked in accept() and, in
    // non-blocking mode, always joins the joinable loop thread.
    void terminate() final {
        photon::thread* th = workth.exchange(nullptr);
        if (!th) return;
        if (waiting) thread_interrupt(th);
        if (!m_block_serv && th != photon::CURRENT)
            thread_join((photon::join_handle*)th);
    }

    Object* get_underlay_object(uint64_t) override {
        return rfd.get_underlay_object(0);
    }

    int setsockopt(int level, int option_name, const void* option_value,
                   socklen_t option_len) override {
        return rfd.setsockopt(level, option_name, option_value, option_len);
    }

    int getsockopt(int level, int option_name, void* option_value,
                   socklen_t* option_len) override {
        return rfd.getsockopt(level, option_name, option_value, option_len);
    }

    uint64_t timeout() const override { return m_timeout; }
    void timeout(uint64_t tm) override { m_timeout = tm; }

    int getsockname(photon::net::EndPoint& addr) override {
        return rfd.getsockname(addr);
    }
    int getpeername(photon::net::EndPoint& addr) override {
        return rfd.getpeername(addr);
    }
    int getsockname(char* path, size_t count) override {
        errno = ENOSYS;
        return -1;
    }
    int getpeername(char* path, size_t count) override {
        errno = ENOSYS;
        return -1;
    }
};

extern "C" photon::net::ISocketClient* new_rsocket_client() {
    return new RSocketClient();
}

// Returns nullptr if the listening rsocket could not be created.
extern "C" photon::net::ISocketServer* new_rsocket_server() {
    auto ret = new RSocketServer();
    if (!ret->rfd) {
        delete ret;
        return nullptr;
    }
    return ret;
}

}  // namespace net
}  // namespace photon