thrift/lib/cpp2/server/Cpp2Worker.h (249 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include <chrono> #include <memory> #include <optional> #include <string_view> #include <unordered_set> #include <variant> #include <folly/container/F14Map.h> #include <folly/io/async/AsyncServerSocket.h> #include <folly/io/async/AsyncTransport.h> #include <folly/io/async/EventBase.h> #include <folly/io/async/EventHandler.h> #include <folly/io/async/HHWheelTimer.h> #include <folly/net/NetworkSocket.h> #include <thrift/lib/cpp/async/TAsyncSSLSocket.h> #include <thrift/lib/cpp2/security/FizzPeeker.h> #include <thrift/lib/cpp2/server/IOWorkerContext.h> #include <thrift/lib/cpp2/server/MemoryTracker.h> #include <thrift/lib/cpp2/server/RequestsRegistry.h> #include <thrift/lib/cpp2/server/ThriftServer.h> #include <thrift/lib/cpp2/server/peeking/TLSHelper.h> #include <wangle/acceptor/Acceptor.h> #include <wangle/acceptor/ConnectionManager.h> #include <wangle/acceptor/PeekingAcceptorHandshakeHelper.h> namespace apache { namespace thrift { // Forward declaration of classes class Cpp2Connection; class ThriftServer; /** * Cpp2Worker drives the actual I/O for ThriftServer connections. * * The ThriftServer itself accepts incoming connections, then hands off each * connection to a Cpp2Worker running in another thread. There should * typically be around one Cpp2Worker thread per core. */ class Cpp2Worker : public IOWorkerContext, public wangle::Acceptor, private wangle::PeekingAcceptorHandshakeHelper::PeekCallback, public std::enable_shared_from_this<Cpp2Worker> { protected: enum { kPeekCount = 9 }; struct DoNotUse {}; public: /** * Cpp2Worker is the actual server object for existing connections. * One or more of these should be created by ThriftServer (one per * CPU core is recommended). * * @param server the ThriftServer which created us. * @param serverChannel existing server channel to use, only for duplex server */ static std::shared_ptr<Cpp2Worker> create( ThriftServer* server, const std::shared_ptr<HeaderServerChannel>& serverChannel = nullptr, folly::EventBase* eventBase = nullptr, std::shared_ptr<fizz::server::CertManager> certManager = nullptr, std::shared_ptr<wangle::SSLContextManager> ctxManager = nullptr, std::shared_ptr<const fizz::server::FizzServerContext> fizzContext = nullptr) { std::shared_ptr<Cpp2Worker> worker(new Cpp2Worker(server, {})); worker->setFizzCertManager(certManager); worker->setSSLContextManager(ctxManager); worker->construct(server, serverChannel, eventBase, fizzContext); return worker; } static std::shared_ptr<Cpp2Worker> createDummy(folly::EventBase* eventBase) { std::shared_ptr<Cpp2Worker> worker(new Cpp2Worker(nullptr, {})); worker->Acceptor::init(nullptr, eventBase); worker->IOWorkerContext::init(*eventBase); return worker; } void init( folly::AsyncServerSocket* serverSocket, folly::EventBase* eventBase, wangle::SSLStats* stats, std::shared_ptr<const fizz::server::FizzServerContext> fizzContext) override { securityProtocolCtxManager_.addPeeker(this); Acceptor::init(serverSocket, eventBase, stats, fizzContext); IOWorkerContext::init(*eventBase); } /* * This is called from ThriftServer::stopDuplex * Necessary for keeping the ThriftServer alive until this Worker dies */ void stopDuplex(std::shared_ptr<ThriftServer> ts); /** * Get underlying server. * * @returns pointer to ThriftServer */ ThriftServer* getServer() const { return server_; } /** * Get a shared_ptr of this Cpp2Worker. */ std::shared_ptr<Cpp2Worker> getWorkerShared() { return shared_from_this(); } /** * SSL stats hook */ void updateSSLStats( const folly::AsyncTransport* sock, std::chrono::milliseconds acceptLatency, wangle::SSLErrorEnum error, const folly::exception_wrapper& ex) noexcept override; void handleHeader( folly::AsyncTransport::UniquePtr sock, const folly::SocketAddress* addr, const wangle::TransportInfo& tinfo); RequestsRegistry* getRequestsRegistry() const { return requestsRegistry_; } bool isStopping() const { return stopping_.load(std::memory_order_relaxed); } struct ActiveRequestsDecrement { void operator()(Cpp2Worker* worker) { if (--worker->activeRequests_ == 0 && worker->isStopping()) { worker->stopBaton_.post(); } } }; using ActiveRequestsGuard = std::unique_ptr<Cpp2Worker, ActiveRequestsDecrement>; ActiveRequestsGuard getActiveRequestsGuard(); class PerServiceMetadata { public: explicit PerServiceMetadata( AsyncProcessorFactory& processorFactory, AsyncProcessorFactory::CreateMethodMetadataResult&& methods) : processorFactory_(processorFactory), methods_(std::move(methods)) {} /** * AsyncProcessorFactory::createMethodMetadata is not implemented. */ using MetadataNotImplemented = std::monostate; /** * The service metadata contained an entry for the provided method name. * Otherwise, if the metadata is WildcardMethodMetadataMap, then this is a * reference to a WildcardMethodMetadata object. * * This aligns with the contracts of MethodMetadataMap and * WildcardMethodMetadataMap. */ struct MetadataFound { const AsyncProcessorFactory::MethodMetadata& metadata; }; /** * The service metadata did not contain an entry for the provided method * name. This should result in an unknown method error. */ struct MetadataNotFound {}; /** * The result type of findMethod() below. */ using FindMethodResult = std::variant<MetadataNotImplemented, MetadataFound, MetadataNotFound>; /** * Looks up the provided method name in the metadata map. * * This returns a valid metadata object per the contract established by * AsyncProcessorFactory::createMethodMetadata. * * This returns MetadataNotFound iff no valid metadata exists. That means * that an unknown method error should be sent. * * This returns MetadataNotImplemented iff the service does not support the * createMethodMetadata() API. */ FindMethodResult findMethod(std::string_view methodName) const; /** * Extracts the base request context from the service based on the result of * findMethod(). * This returns nullptr iff no metadata was found or createMethodMetadata() * is not implemented. */ std::shared_ptr<folly::RequestContext> getBaseContextForRequest( const FindMethodResult&) const; private: AsyncProcessorFactory& processorFactory_; AsyncProcessorFactory::CreateMethodMetadataResult methods_; }; /** * Gets the per-IO-thread metadata stored per-service. The metadata is lazily * created and the same object is returned for subsequent calls that pass the * same service. */ PerServiceMetadata& getMetadataForService( AsyncProcessorFactory& processorFactory) const { getEventBase()->dcheckIsInEventBaseThread(); if (auto metadata = folly::get_ptr(perServiceMetadata_, &processorFactory)) { return *metadata; } auto [metadata, _] = perServiceMetadata_.emplace( &processorFactory, PerServiceMetadata{ processorFactory, processorFactory.createMethodMetadata()}); return metadata->second; } protected: Cpp2Worker( ThriftServer* server, DoNotUse /* ignored, never call constructor directly */) : Acceptor( server ? server->getServerSocketConfig() : wangle::ServerSocketConfig()), wangle::PeekingAcceptorHandshakeHelper::PeekCallback(kPeekCount), server_(server), activeRequests_(0) { if (server) { // Leave enough headroom to close connections ungracefully before the // worker join timeout expires. constexpr auto kGracefulTimeoutHeadroom = std::chrono::milliseconds{500}; setGracefulShutdownTimeout(std::max( server->getWorkersJoinTimeout() - kGracefulTimeoutHeadroom, std::chrono::milliseconds::zero())); } } void construct( ThriftServer* server, const std::shared_ptr<HeaderServerChannel>& serverChannel, folly::EventBase* eventBase, std::shared_ptr<const fizz::server::FizzServerContext> fizzContext) { auto observer = std::dynamic_pointer_cast<folly::EventBaseObserver>( server_->getObserverShared()); if (serverChannel) { eventBase = serverChannel->getEventBase(); } else if (!eventBase) { eventBase = folly::EventBaseManager::get()->getEventBase(); } init(nullptr, eventBase, nullptr, fizzContext); initRequestsRegistry(); if (serverChannel) { // duplex useExistingChannel(serverChannel); } if (observer) { eventBase->add([eventBase, observer = std::move(observer)] { eventBase->setObserver(observer); }); } // We distribute the memory limit averaged out over all IO workers. This // avoids the need to synchronize memory usage counts with other IO threads. // folly::AsyncServerSocket hands out connections to IO workers in a // round-robin manner so we should expect a roughly uniform distribution of // payload sizes. ingressMemoryTracker_ = std::make_unique<MemoryTracker>( folly::observer::makeObserver([server]() -> size_t { return **server->getIngressMemoryLimitObserver() / server->getNumIOWorkerThreads(); }), server->getMinPayloadSizeToEnforceIngressMemoryLimitObserver()); egressMemoryTracker_ = std::make_unique<MemoryTracker>( folly::observer::makeObserver([server]() -> size_t { return **server->getEgressMemoryLimitObserver() / server->getNumIOWorkerThreads(); })); } void onNewConnection( folly::AsyncTransport::UniquePtr, const folly::SocketAddress*, const std::string&, wangle::SecureTransportType, const wangle::TransportInfo&) override; virtual std::shared_ptr<folly::AsyncTransport> createThriftTransport( folly::AsyncTransport::UniquePtr); void markSocketAccepted(folly::AsyncSocket* sock); void plaintextConnectionReady( folly::AsyncSocket::UniquePtr sock, const folly::SocketAddress& clientAddr, wangle::TransportInfo& tinfo) override; void requestStop(); // returns false if timed out due to deadline bool waitForStop(std::chrono::steady_clock::time_point deadline); virtual wangle::AcceptorHandshakeHelper::UniquePtr createSSLHelper( const std::vector<uint8_t>& bytes, const folly::SocketAddress& clientAddr, std::chrono::steady_clock::time_point acceptTime, wangle::TransportInfo& tinfo); wangle::DefaultToFizzPeekingCallback* getFizzPeeker() override { return &fizzPeeker_; } MemoryTracker& getIngressMemoryTracker() { return *ingressMemoryTracker_; } MemoryTracker& getEgressMemoryTracker() { return *egressMemoryTracker_; } private: /// The mother ship. ThriftServer* server_; FizzPeeker fizzPeeker_; // For DuplexChannel case, set only during shutdown so that we can extend the // lifetime of the ThriftServer if the Worker is kept alive by some // Connections which are kept alive by in-flight requests std::shared_ptr<ThriftServer> duplexServer_; // We expect to have one processor factory per InterfaceKind. Using F14NodeMap // guarantees reference stability. mutable folly::F14NodeMap<AsyncProcessorFactory*, PerServiceMetadata> perServiceMetadata_; folly::AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket( const std::shared_ptr<folly::SSLContext>& ctx, folly::EventBase* base, int fd, const folly::SocketAddress* peerAddress) override { return folly::AsyncSSLSocket::UniquePtr( new apache::thrift::async::TAsyncSSLSocket( ctx, base, folly::NetworkSocket::fromFd(fd), true, /* set server */ true /* defer the security negotiation until sslAccept. */, peerAddress)); } /** * For a duplex Thrift server, use an existing channel */ void useExistingChannel( const std::shared_ptr<HeaderServerChannel>& serverChannel); void cancelQueuedRequests(); uint32_t activeRequests_; RequestsRegistry* requestsRegistry_; std::atomic<bool> stopping_{false}; folly::Baton<> stopBaton_; std::unique_ptr<MemoryTracker> ingressMemoryTracker_; std::unique_ptr<MemoryTracker> egressMemoryTracker_; void initRequestsRegistry(); wangle::AcceptorHandshakeHelper::UniquePtr getHelper( const std::vector<uint8_t>& bytes, const folly::SocketAddress& clientAddr, std::chrono::steady_clock::time_point acceptTime, wangle::TransportInfo& tinfo) override; bool isPlaintextAllowedOnLoopback() { return server_->isPlaintextAllowedOnLoopback(); } SSLPolicy getSSLPolicy() { return server_->getSSLPolicy(); } bool shouldPerformSSL( const std::vector<uint8_t>& bytes, const folly::SocketAddress& clientAddr); std::optional<ThriftParametersContext> getThriftParametersContext(); friend class Cpp2Connection; friend class ThriftServer; friend class RocketRoutingHandler; friend class TestRoutingHandler; }; } // namespace thrift } // namespace apache