mcrouter/lib/network/AsyncTlsToPlaintextSocket.cpp (111 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include "mcrouter/lib/network/AsyncTlsToPlaintextSocket.h" #include <chrono> #include <functional> #include <utility> #include <folly/GLog.h> #include <folly/ScopeGuard.h> #include <folly/io/SocketOptionMap.h> #include <folly/io/async/AsyncSocket.h> #include <folly/io/async/ssl/BasicTransportCertificate.h> #include <folly/portability/OpenSSL.h> #include <folly/io/async/AsyncSocket.h> #include <thrift/lib/cpp/async/TAsyncSSLSocket.h> namespace facebook { namespace memcache { class AsyncTlsToPlaintextSocket::ConnectCallback : public folly::AsyncSocket::ConnectCallback { public: ConnectCallback( AsyncTlsToPlaintextSocket& me, folly::AsyncSocket::ConnectCallback* connectCallback) : me_(me), connectCallback_(connectCallback) {} void connectSuccess() noexcept override { DestructorGuard dg{&me_}; SCOPE_EXIT { if (auto* cb = std::exchange(connectCallback_, nullptr)) { cb->connectSuccess(); } delete this; }; auto& impl = me_.impl_; auto activateSocket = [&] { impl->setSendTimeout(me_.writeTimeout_.count()); impl->setReadCB(std::exchange(me_.readCallback_, nullptr)); me_.state_ = State::CONNECTED; me_.flushWrites(); }; auto* tlsSocket = impl->getUnderlyingTransport<apache::thrift::async::TAsyncSSLSocket>(); CHECK(tlsSocket); // Save state regarding session resumption if (tlsSocket->sessionResumptionAttempted()) { me_.resumptionStatus_ = tlsSocket->getSSLSessionReused() ? SessionResumptionStatus::RESUMPTION_ATTEMPTED_AND_SUCCEEDED : SessionResumptionStatus::RESUMPTION_ATTEMPTED_AND_FAILED; } if (tlsSocket->getApplicationProtocol() != kMcSecurityTlsToPlaintextProto) { FB_LOG_EVERY_MS(ERROR, 10) << "Failed to negotiate plaintext fallback. Falling back to full TLS."; // Even if the server fails the handshake, we must be sure to drain any // pending writes activateSocket(); return; } // Save peer cert auto peerCert = folly::ssl::BasicTransportCertificate::create( tlsSocket->getPeerCertificate()); // We need to mark the SSL as shutdown here, but need to do // it quietly so no alerts are sent over the wire. // This prevents SSL thinking we are shutting down in a bad state // when AsyncSSLSocket is cleaned up, which could remove the session // from the session cache auto* ssl = const_cast<SSL*>(tlsSocket->getSSL()); SSL_set_quiet_shutdown(ssl, 1); SSL_shutdown(ssl); DCHECK_EQ(0, tlsSocket->getZeroCopyBufId()); impl.reset( new folly::AsyncSocket(&me_.evb_, tlsSocket->detachNetworkSocket())); activateSocket(); impl->getUnderlyingTransport<folly::AsyncSocket>()->setPeerCertificate( std::move(peerCert)); } void connectErr(const folly::AsyncSocketException& ex) noexcept override { DestructorGuard dg{&me_}; me_.failAllWrites(ex); if (auto* readCallback = std::exchange(me_.readCallback_, nullptr)) { readCallback->readErr(ex); } if (auto* cb = std::exchange(connectCallback_, nullptr)) { cb->connectErr(ex); } delete this; } private: AsyncTlsToPlaintextSocket& me_; folly::AsyncSocket::ConnectCallback* connectCallback_{nullptr}; }; void AsyncTlsToPlaintextSocket::connect( folly::AsyncSocket::ConnectCallback* connectCallback, const folly::SocketAddress& address, std::chrono::milliseconds connectTimeout, folly::SocketOptionMap socketOptions) { auto* const wrappedConnectCallback = new ConnectCallback(*this, connectCallback); impl_->getUnderlyingTransport<apache::thrift::async::TAsyncSSLSocket>() ->connect( wrappedConnectCallback, address, connectTimeout.count(), std::move(socketOptions)); } void AsyncTlsToPlaintextSocket::flushWrites() { while (!bufferedWrites_.empty()) { auto& bufferedWrite = bufferedWrites_.front(); auto* cb = bufferedWrite.callback; auto buf = std::move(bufferedWrite.buf); bufferedWrites_.pop_front(); impl_->writeChain(cb, std::move(buf)); } } void AsyncTlsToPlaintextSocket::failAllWrites( const folly::AsyncSocketException& ex) { while (!bufferedWrites_.empty()) { auto& bufferedWrite = bufferedWrites_.front(); auto* cb = bufferedWrite.callback; bufferedWrites_.pop_front(); cb->writeErr(0, ex); } } } // namespace memcache } // namespace facebook