xplat/FlipperWebSocket/WebSocketTLSClient.cpp (254 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #ifdef FB_SONARKIT_ENABLED #include "WebSocketTLSClient.h" #include <Flipper/ConnectionContextStore.h> #include <Flipper/FlipperTransportTypes.h> #include <Flipper/FlipperURLSerializer.h> #include <Flipper/Log.h> #include <folly/String.h> #include <folly/futures/Future.h> #include <folly/io/async/AsyncSocketException.h> #include <folly/io/async/SSLContext.h> #include <folly/json.h> #include <websocketpp/common/memory.hpp> #include <websocketpp/common/thread.hpp> #include <websocketpp/config/asio.hpp> #include <cctype> #include <iomanip> #include <sstream> #include <stdexcept> #include <string> #include <thread> namespace facebook { namespace flipper { WebSocketTLSClient::WebSocketTLSClient( FlipperConnectionEndpoint endpoint, std::unique_ptr<FlipperSocketBasePayload> payload, folly::EventBase* eventBase) : WebSocketTLSClient( std::move(endpoint), std::move(payload), eventBase, nullptr) {} WebSocketTLSClient::WebSocketTLSClient( FlipperConnectionEndpoint endpoint, std::unique_ptr<FlipperSocketBasePayload> payload, folly::EventBase* eventBase, ConnectionContextStore* connectionContextStore) : BaseClient( std::move(endpoint), std::move(payload), eventBase, connectionContextStore) { status_ = Status::Unconnected; socket_.clear_access_channels(websocketpp::log::alevel::all); socket_.clear_error_channels(websocketpp::log::elevel::all); socket_.init_asio(); socket_.start_perpetual(); thread_ = websocketpp::lib::make_shared<websocketpp::lib::thread>( &SocketTLSClient::run, &socket_); } WebSocketTLSClient::~WebSocketTLSClient() { disconnect(); } bool WebSocketTLSClient::connect(FlipperConnectionManager* manager) { if (status_ != Status::Unconnected) { return false; } status_ = Status::Connecting; std::string connectionURL = endpoint_.secure ? "wss://" : "ws://"; connectionURL += endpoint_.host; connectionURL += ":"; connectionURL += std::to_string(endpoint_.port); auto serializer = URLSerializer{}; payload_->serialize(serializer); auto payload = serializer.serialize(); if (payload.size()) { connectionURL += "/?"; connectionURL += payload; } socket_.set_tls_init_handler(bind( &WebSocketTLSClient::onTLSInit, this, endpoint_.host.c_str(), websocketpp::lib::placeholders::_1)); auto uri = websocketpp::lib::make_shared<websocketpp::uri>(connectionURL); websocketpp::lib::error_code ec; connection_ = socket_.get_connection(uri, ec); if (ec) { status_ = Status::Failed; return false; } handle_ = connection_->get_handle(); connection_->set_open_handler(websocketpp::lib::bind( &WebSocketTLSClient::onOpen, this, &socket_, websocketpp::lib::placeholders::_1)); connection_->set_message_handler(websocketpp::lib::bind( &WebSocketTLSClient::onMessage, this, &socket_, websocketpp::lib::placeholders::_1, websocketpp::lib::placeholders::_2)); connection_->set_fail_handler(websocketpp::lib::bind( &WebSocketTLSClient::onFail, this, &socket_, websocketpp::lib::placeholders::_1)); connection_->set_close_handler(websocketpp::lib::bind( &WebSocketTLSClient::onClose, this, &socket_, websocketpp::lib::placeholders::_1)); auto connected = connected_.get_future(); socket_.connect(connection_); auto state = connected.wait_for(std::chrono::seconds(10)); if (state == std::future_status::ready) { return connected.get(); } disconnect(); return false; } void WebSocketTLSClient::disconnect() { socket_.stop_perpetual(); if (status_ == Status::Connecting || status_ == Status::Open || status_ == Status::Failed) { websocketpp::lib::error_code ec; socket_.close(handle_, websocketpp::close::status::going_away, "", ec); } socket_.stop(); status_ = Status::Closed; if (thread_ && thread_->joinable()) { thread_->join(); } thread_ = nullptr; eventBase_->add( [eventHandler = eventHandler_]() { eventHandler(SocketEvent::CLOSE); }); } void WebSocketTLSClient::send( const folly::dynamic& message, SocketSendHandler completion) { std::string json = folly::toJson(message); send(json, std::move(completion)); } void WebSocketTLSClient::send( const std::string& message, SocketSendHandler completion) { websocketpp::lib::error_code ec; socket_.send( handle_, &message[0], message.size(), websocketpp::frame::opcode::text, ec); completion(); } /** Only ever used for insecure connections to receive the device_id from a signCertificate request. If the intended usage ever changes, then a better approach needs to be put in place. */ void WebSocketTLSClient::sendExpectResponse( const std::string& message, SocketSendExpectResponseHandler completion) { connection_->set_message_handler( [completion, eventBase = eventBase_]( websocketpp::connection_hdl hdl, SocketTLSClient::message_ptr msg) { const std::string& payload = msg->get_payload(); eventBase->add([completion, payload] { completion(payload, false); }); }); websocketpp::lib::error_code ec; socket_.send( handle_, &message[0], message.size(), websocketpp::frame::opcode::text, ec); if (ec) { auto reason = ec.message(); completion(reason, true); } } void WebSocketTLSClient::onOpen( SocketTLSClient* c, websocketpp::connection_hdl hdl) { if (status_ == Status::Connecting) { connected_.set_value(true); } status_ = Status::Initializing; eventBase_->add( [eventHandler = eventHandler_]() { eventHandler(SocketEvent::OPEN); }); } void WebSocketTLSClient::onMessage( SocketTLSClient* c, websocketpp::connection_hdl hdl, SocketTLSClient::message_ptr msg) { const std::string& payload = msg->get_payload(); if (messageHandler_) { eventBase_->add([payload, messageHandler = messageHandler_]() { messageHandler(payload); }); } } void WebSocketTLSClient::onFail( SocketTLSClient* c, websocketpp::connection_hdl hdl) { SocketTLSClient::connection_ptr con = c->get_con_from_hdl(hdl); auto server = con->get_response_header("Server"); auto reason = con->get_ec().message(); auto sslError = (reason.find("TLS handshake failed") != std::string::npos || reason.find("Generic TLS related error") != std::string::npos); if (status_ == Status::Connecting) { if (sslError) { try { connected_.set_exception( std::make_exception_ptr(folly::AsyncSocketException( folly::AsyncSocketException::SSL_ERROR, "SSL handshake failed"))); } catch (...) { // set_exception() may throw an exception // In that case, just set the value to false. connected_.set_value(false); } } else { connected_.set_value(false); } } status_ = Status::Failed; eventBase_->add([eventHandler = eventHandler_, sslError]() { if (sslError) { eventHandler(SocketEvent::SSL_ERROR); } else { eventHandler(SocketEvent::ERROR); } }); } void WebSocketTLSClient::onClose( SocketTLSClient* c, websocketpp::connection_hdl hdl) { status_ = Status::Closed; eventBase_->add( [eventHandler = eventHandler_]() { eventHandler(SocketEvent::CLOSE); }); } SocketTLSContext WebSocketTLSClient::onTLSInit( const char* hostname, websocketpp::connection_hdl) { namespace asio = websocketpp::lib::asio; SocketTLSContext ctx = websocketpp::lib::make_shared<asio::ssl::context>( asio::ssl::context::sslv23); ctx->set_options( asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use); ctx->set_verify_mode(asio::ssl::verify_peer); ctx->load_verify_file(connectionContextStore_->getPath( ConnectionContextStore::StoreItem::FLIPPER_CA)); asio::error_code error; ctx->use_certificate_file( connectionContextStore_->getPath( ConnectionContextStore::StoreItem::CLIENT_CERT), asio::ssl::context::pem, error); ctx->use_private_key_file( connectionContextStore_->getPath( ConnectionContextStore::StoreItem::PRIVATE_KEY), asio::ssl::context::pem, error); return ctx; } } // namespace flipper } // namespace facebook #endif