utils/include/utils/net/AsioSocketUtils.h (139 lines of code) (raw):

/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #ifndef WIN32 #include <ifaddrs.h> #endif #include <string> #include <utility> #include <tuple> #include <memory> #include "asio/ssl.hpp" #include "asio/ip/tcp.hpp" #include "AsioCoro.h" #include "utils/Hash.h" #include "utils/StringUtils.h" // for string <=> on libc++ #include "minifi-cpp/controllers/SSLContextService.h" #include "io/BaseStream.h" #include "utils/Deleters.h" #include "utils/net/Socket.h" #include "core/logging/LoggerFactory.h" namespace org::apache::nifi::minifi::utils::net { using HandshakeType = asio::ssl::stream_base::handshake_type; using TcpSocket = asio::ip::tcp::socket; using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>; constexpr auto MINIFI_SSL_OPTIONS = asio::ssl::context::default_workarounds | asio::ssl::context::single_dh_use | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 | asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1; class ConnectionId { public: ConnectionId(std::string hostname, std::string port) : hostname_(std::move(hostname)), service_(std::move(port)) {} ConnectionId(const ConnectionId& connection_id) = default; ConnectionId(ConnectionId&& connection_id) = default; ConnectionId& operator=(ConnectionId&&) = default; ConnectionId& operator=(const ConnectionId&) = default; auto operator<=>(const ConnectionId&) const = default; [[nodiscard]] std::string_view getHostname() const { return hostname_; } [[nodiscard]] std::string_view getService() const { return service_; } private: std::string hostname_; std::string service_; }; template<class SocketType> asio::awaitable<std::tuple<std::error_code>> handshake(SocketType&, asio::steady_timer::duration) = delete; template<> asio::awaitable<std::tuple<std::error_code>> handshake(TcpSocket&, asio::steady_timer::duration); template<> asio::awaitable<std::tuple<std::error_code>> handshake(SslSocket& socket, asio::steady_timer::duration); asio::ssl::context getSslContext(const controllers::SSLContextService& ssl_context_service, asio::ssl::context::method ssl_context_method = asio::ssl::context::tls_client); struct SocketData { std::string host = "localhost"; int port = -1; std::shared_ptr<minifi::controllers::SSLContextService> ssl_context_service; }; class AsioSocketConnection : public io::BaseStreamImpl { public: explicit AsioSocketConnection(SocketData socket_data); int initialize() override; size_t read(std::span<std::byte> out_buffer) override { gsl_Expects(stream_); return stream_->read(out_buffer); } size_t write(const uint8_t *in_buffer, size_t len) override { gsl_Expects(stream_); return stream_->write(in_buffer, len); } void setInterface(const std::string& local_network_interface) { local_network_interface_ = local_network_interface; } private: #ifndef WIN32 template<typename SocketType> void bindToLocalInterfaceIfSpecified(SocketType& socket) { if (local_network_interface_.empty()) { return; } using ifaddrs_uniq_ptr = std::unique_ptr<ifaddrs, utils::ifaddrs_deleter>; const auto if_list_ptr = []() -> ifaddrs_uniq_ptr { ifaddrs *list = nullptr; [[maybe_unused]] const auto get_ifa_success = getifaddrs(&list) == 0; assert(get_ifa_success || !list); return ifaddrs_uniq_ptr{ list }; }(); if (!if_list_ptr) { return; } const auto advance_func = [](const ifaddrs *const p) { return p->ifa_next; }; const auto predicate = [this](const ifaddrs *const item) { return item->ifa_addr && item->ifa_name && (item->ifa_addr->sa_family == AF_INET || item->ifa_addr->sa_family == AF_INET6) && item->ifa_name == local_network_interface_; }; auto item_found = [&]() -> ifaddrs* { for (auto it = if_list_ptr.get(); it; it = advance_func(it)) { if (predicate(it)) { return it; } } return nullptr; }(); if (item_found == nullptr) { logger_->log_error("Could not find specified network interface: '{}'", local_network_interface_); return; } std::string address; try { address = utils::net::sockaddr_ntop(item_found->ifa_addr); } catch(const std::exception& ex) { logger_->log_error("Error occurred while getting network interface address: '{}'", ex.what()); return; } asio::ip::tcp::endpoint local_endpoint(asio::ip::address::from_string(address), 0); asio::error_code err; socket.open(local_endpoint.protocol(), err); if (err) { logger_->log_error("Failed to open socket on network interface '{}' with the following message: '{}'", local_network_interface_, err.message()); return; } socket.bind(local_endpoint, err); if (err) { logger_->log_error("Failed to bind to network interface '{}' with the following message: '{}'", local_network_interface_, err.message()); return; } } #endif bool connectTcpSocketOverSsl(); bool connectTcpSocket(); asio::io_context io_context_; std::unique_ptr<io::BaseStream> stream_; SocketData socket_data_; std::string local_network_interface_; std::shared_ptr<core::logging::Logger> logger_{core::logging::LoggerFactory<AsioSocketConnection>::getLogger()}; }; } // namespace org::apache::nifi::minifi::utils::net template<> struct std::hash<org::apache::nifi::minifi::utils::net::ConnectionId> { size_t operator()(const org::apache::nifi::minifi::utils::net::ConnectionId& connection_id) const noexcept { return org::apache::nifi::minifi::utils::hash_combine( std::hash<std::string_view>{}(connection_id.getHostname()), std::hash<std::string_view>{}(connection_id.getService())); } }; template <typename InternetProtocol> struct fmt::formatter<asio::ip::basic_endpoint<InternetProtocol>> : fmt::ostream_formatter {}; template <> struct fmt::formatter<asio::ip::address> : fmt::ostream_formatter {};