// thrift/lib/cpp2/server/Cpp2Worker.cpp (340 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrift/lib/cpp2/server/Cpp2Worker.h>
#include <vector>
#include <glog/logging.h>
#include <folly/Overload.h>
#include <folly/String.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBaseLocal.h>
#include <folly/portability/Sockets.h>
#include <thrift/lib/cpp/async/TAsyncSSLSocket.h>
#include <thrift/lib/cpp/concurrency/Util.h>
#include <thrift/lib/cpp2/async/ResponseChannel.h>
#include <thrift/lib/cpp2/security/extensions/ThriftParametersContext.h>
#include <thrift/lib/cpp2/server/Cpp2Connection.h>
#include <thrift/lib/cpp2/server/LoggingEvent.h>
#include <thrift/lib/cpp2/server/ThriftServer.h>
#include <thrift/lib/cpp2/server/peeking/PeekingManager.h>
#include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h>
#include <wangle/acceptor/EvbHandshakeHelper.h>
#include <wangle/acceptor/SSLAcceptorHandshakeHelper.h>
#include <wangle/acceptor/UnencryptedAcceptorHandshakeHelper.h>
namespace apache {
namespace thrift {
namespace {
// File-local holder for the per-EventBase RequestsRegistry instances.
// initRequestsRegistry() obtains (or creates) the one for a given EventBase
// via try_emplace. folly::LeakySingleton intentionally never destroys its
// instance.
folly::LeakySingleton<folly::EventBaseLocal<RequestsRegistry>> registry;
} // namespace
void Cpp2Worker::initRequestsRegistry() {
auto* evb = getEventBase();
auto memPerReq = server_->getMaxDebugPayloadMemoryPerRequest();
auto memPerWorker = server_->getMaxDebugPayloadMemoryPerWorker();
auto maxFinished = server_->getMaxFinishedDebugPayloadsPerWorker();
std::weak_ptr<Cpp2Worker> self_weak = shared_from_this();
evb->runInEventBaseThread([=, self_weak = std::move(self_weak)]() {
if (auto self = self_weak.lock()) {
self->requestsRegistry_ = ®istry.get().try_emplace(
*evb, memPerReq, memPerWorker, maxFinished);
}
});
}
// Entry point for a newly accepted (and, for TLS, handshaken) connection.
//
// The connection is dropped when the worker is already stopping, or when this
// worker's share of the server's max-connection limit is used up (the
// observer, if any, is told via connDropped()/connRejected()). Otherwise
// zero-copy is optionally enabled and the socket is dispatched by security:
//  - NONE: handed to a TransportPeekingManager;
//  - TLS: given to the first routing handler that accepts the negotiated
//    protocol name; if none does, handled as header when the server is duplex,
//    otherwise handed to a TransportPeekingManager;
//  - anything else: logged and dropped.
void Cpp2Worker::onNewConnection(
    folly::AsyncTransport::UniquePtr sock,
    const folly::SocketAddress* addr,
    const std::string& nextProtocolName,
    wangle::SecureTransportType secureTransportType,
    const wangle::TransportInfo& tinfo) {
  // A connection accepted before stopListening() may only finish its
  // handshake after stopCPUWorkers(); such late arrivals are ignored.
  if (stopping_) {
    return;
  }

  auto* const observer = server_->getObserver();
  const uint32_t maxConnection = server_->getMaxConnections();
  if (maxConnection > 0) {
    // The server-wide limit is split evenly across the IO worker threads.
    const auto perWorkerLimit =
        maxConnection / server_->getNumIOWorkerThreads();
    if (getConnectionManager()->getNumConnections() >= perWorkerLimit) {
      if (observer != nullptr) {
        observer->connDropped();
        observer->connRejected();
      }
      return;
    }
  }

  if (const auto& zeroCopyFunc = server_->getZeroCopyEnableFunc();
      zeroCopyFunc && sock) {
    sock->setZeroCopy(true);
    sock->setZeroCopyEnableFunc(zeroCopyFunc);
  }

  switch (secureTransportType) {
    case wangle::SecureTransportType::NONE:
      // No transport security: peek into the socket to determine its type.
      new TransportPeekingManager(
          shared_from_this(), *addr, tinfo, server_, std::move(sock));
      return;

    case wangle::SecureTransportType::TLS: {
      // Let a routing handler claim the connection based on the announced
      // protocol, when one was negotiated.
      if (!nextProtocolName.empty()) {
        for (auto& handler : *server_->getRoutingHandlers()) {
          if (!handler->canAcceptEncryptedConnection(nextProtocolName)) {
            continue;
          }
          VLOG(4) << "Cpp2Worker: Routing encrypted connection for protocol "
                  << nextProtocolName;
          handler->handleConnection(
              getConnectionManager(),
              std::move(sock),
              addr,
              tinfo,
              shared_from_this());
          return;
        }
      }
      // Nobody claimed it by protocol name.
      if (getServer()->isDuplex()) {
        handleHeader(std::move(sock), addr, tinfo);
      } else {
        new TransportPeekingManager(
            shared_from_this(), *addr, tinfo, server_, std::move(sock));
      }
      return;
    }

    default:
      LOG(ERROR) << "Unsupported Secure Transport Type";
      return;
  }
}
// Wraps an accepted transport in a Cpp2Connection, registers it with this
// worker's Acceptor, and starts it. Afterwards the observer (if any) receives
// this worker's connection count multiplied by the number of IO worker
// threads via activeConnections().
//
// The file descriptor is used only for logging. It is read through the
// underlying folly::AsyncSocket, which may be absent: createThriftTransport
// explicitly tolerates a missing AsyncSocket for fizz transports. In that case
// -1 is logged instead of dereferencing a null pointer.
void Cpp2Worker::handleHeader(
    folly::AsyncTransport::UniquePtr sock,
    const folly::SocketAddress* addr,
    const wangle::TransportInfo& tinfo) {
  auto* const asyncSock = sock->getUnderlyingTransport<folly::AsyncSocket>();
  const auto fd = asyncSock ? asyncSock->getNetworkSocket().toFd() : -1;
  VLOG(4) << "Cpp2Worker: Creating connection for socket " << fd;
  auto thriftTransport = createThriftTransport(std::move(sock));
  auto connection = std::make_shared<Cpp2Connection>(
      std::move(thriftTransport), addr, shared_from_this(), nullptr);
  Acceptor::addConnection(connection.get());
  connection->addConnection(connection, tinfo);
  connection->start();
  VLOG(4) << "Cpp2Worker: created connection for socket " << fd;
  if (auto observer = server_->getObserver()) {
    observer->activeConnections(
        getConnectionManager()->getNumConnections() *
        server_->getNumIOWorkerThreads());
  }
}
// Converts an accepted transport into the shared_ptr form used by
// Cpp2Connection. The underlying folly::AsyncSocket, when present, is first
// registered via markSocketAccepted().
//
// Fizz servers are re-owned by a shared_ptr using AsyncFizzServer's own
// Destructor, and their AsyncSocket is optional. Every other transport must
// sit on a folly::AsyncSocket (CHECK-enforced). It is converted with
// convertToShared, whose custom deleter allows socket transfer from header to
// rocket (if enabled by ThriftFlags).
std::shared_ptr<folly::AsyncTransport> Cpp2Worker::createThriftTransport(
    folly::AsyncTransport::UniquePtr sock) {
  if (auto* fizzServer =
          dynamic_cast<fizz::server::AsyncFizzServer*>(sock.get())) {
    if (auto* underlying =
            sock->getUnderlyingTransport<folly::AsyncSocket>()) {
      markSocketAccepted(underlying);
    }
    // Ownership moves from the unique_ptr to the shared_ptr returned below.
    sock.release();
    return std::shared_ptr<fizz::server::AsyncFizzServer>(
        fizzServer, fizz::server::AsyncFizzServer::Destructor());
  }

  auto* const plainSocket = sock->getUnderlyingTransport<folly::AsyncSocket>();
  CHECK(plainSocket);
  markSocketAccepted(plainSocket);
  return apache::thrift::transport::detail::convertToShared(std::move(sock));
}
// Attaches an accepted socket to the server's shutdown socket set.
void Cpp2Worker::markSocketAccepted(folly::AsyncSocket* sock) {
  const auto& shutdownSet = server_->wShutdownSocketSet_;
  sock->setShutdownSocketSet(shutdownSet);
}
// Handles a plaintext socket: attaches it to the server's shutdown socket set
// and hands it to a CheckTLSPeekingManager together with the server observer.
// The manager is allocated with `new` and not stored here. Presumably it
// manages its own lifetime (TODO confirm against CheckTLSPeekingManager).
void Cpp2Worker::plaintextConnectionReady(
    folly::AsyncSocket::UniquePtr sock,
    const folly::SocketAddress& clientAddr,
    wangle::TransportInfo& tinfo) {
  markSocketAccepted(sock.get());
  new CheckTLSPeekingManager(
      shared_from_this(),
      clientAddr,
      tinfo,
      server_,
      std::move(sock),
      server_->getObserverShared());
}
void Cpp2Worker::useExistingChannel(
const std::shared_ptr<HeaderServerChannel>& serverChannel) {
folly::SocketAddress address;
auto conn = std::make_shared<Cpp2Connection>(
nullptr, &address, shared_from_this(), serverChannel);
Acceptor::getConnectionManager()->addConnection(conn.get(), false);
conn->addConnection(conn);
conn->start();
}
// Drains this worker's connections and pins `myServer` in duplexServer_ so
// the raw server_ pointer stays valid while the worker holds it.
void Cpp2Worker::stopDuplex(std::shared_ptr<ThriftServer> myServer) {
  // They better have given us the correct ThriftServer
  DCHECK(server_ == myServer.get());
  // This does not really fully drain everything but at least
  // prevents the connections from accepting new requests
  wangle::Acceptor::drainAllConnections();
  // Capture a shared_ptr to our ThriftServer making sure it will outlive us.
  // Otherwise our raw pointer to it (server_) will be jeopardized.
  // The parameter is already our own copy, so move it rather than paying for
  // another atomic refcount increment/decrement.
  duplexServer_ = std::move(myServer);
}
// Reports the outcome of a TLS handshake to the server observer, if any.
// On success it emits tlsComplete(), plus tlsResumption() for resumed
// sessions and tlsWithClientCert() when the peer presented a certificate.
// On failure (transport not good, or `error` != NO_ERROR) it emits tlsError().
// Fizz servers and folly::AsyncSSLSocket are recognized; any other transport,
// or a null `sock`, is ignored.
void Cpp2Worker::updateSSLStats(
    const folly::AsyncTransport* sock,
    std::chrono::milliseconds /* acceptLatency */,
    wangle::SSLErrorEnum error,
    const folly::exception_wrapper& /*ex*/) noexcept {
  if (sock == nullptr) {
    return;
  }
  auto* const observer = getServer()->getObserver();
  if (observer == nullptr) {
    return;
  }
  const bool noError = error == wangle::SSLErrorEnum::NO_ERROR;

  if (const auto* fizzServer =
          sock->getUnderlyingTransport<fizz::server::AsyncFizzServer>()) {
    if (!sock->good() || !noError) {
      observer->tlsError();
      return;
    }
    observer->tlsComplete();
    const auto pskType = fizzServer->getState().pskType();
    if (pskType && *pskType == fizz::PskType::Resumption) {
      observer->tlsResumption();
    }
    if (fizzServer->getPeerCertificate()) {
      observer->tlsWithClientCert();
    }
    return;
  }

  const auto* sslSocket = sock->getUnderlyingTransport<folly::AsyncSSLSocket>();
  if (sslSocket == nullptr) {
    return;
  }
  if (!sslSocket->good() || !noError) {
    observer->tlsError();
    return;
  }
  observer->tlsComplete();
  if (sslSocket->getSSLSessionReused()) {
    observer->tlsResumption();
  }
  if (sslSocket->getPeerCertificate()) {
    observer->tlsWithClientCert();
  }
}
// Returns the handshake helper for a connection that should use TLS.
// Without fizz enabled, the default peeking callback supplies the helper.
// With fizz enabled, the fizz peeker does. Before asking the fizz peeker, its
// ThriftParametersContext is refreshed whenever getThriftParametersContext()
// yields one.
wangle::AcceptorHandshakeHelper::UniquePtr Cpp2Worker::createSSLHelper(
    const std::vector<uint8_t>& bytes,
    const folly::SocketAddress& clientAddr,
    std::chrono::steady_clock::time_point acceptTime,
    wangle::TransportInfo& tInfo) {
  if (!accConfig_.fizzConfig.enableFizz) {
    return defaultPeekingCallback_.getHelper(
        bytes, clientAddr, acceptTime, tInfo);
  }
  auto parametersContext = getThriftParametersContext();
  if (parametersContext) {
    fizzPeeker_.setThriftParametersContext(
        folly::copy_to_shared_ptr(*parametersContext));
  }
  return getFizzPeeker()->getHelper(bytes, clientAddr, acceptTime, tInfo);
}
// Decides from the configured SSLPolicy and the connection's first `bytes`
// whether the connection must go through a TLS handshake:
//  - REQUIRED: always, unless plaintext is allowed on loopback and the client
//    is a loopback address whose bytes do not look like TLS;
//  - DISABLED: never;
//  - any other policy: only when the bytes look like TLS.
bool Cpp2Worker::shouldPerformSSL(
    const std::vector<uint8_t>& bytes, const folly::SocketAddress& clientAddr) {
  const auto sslPolicy = getSSLPolicy();
  if (sslPolicy != SSLPolicy::REQUIRED) {
    return sslPolicy != SSLPolicy::DISABLED && TLSHelper::looksLikeTLS(bytes);
  }
  if (!isPlaintextAllowedOnLoopback()) {
    return true;
  }
  // Loopback clients may still be sending TLS, so plaintext is accepted only
  // when the peer is loopback AND its bytes don't look like TLS.
  const bool plaintextFromLoopback =
      clientAddr.isLoopbackAddress() && !TLSHelper::looksLikeTLS(bytes);
  return !plaintextFromLoopback;
}
// Builds the ThriftParametersContext that createSSLHelper installs on the
// fizz peeker. Returns std::nullopt when there is no "thrift_tls_config"
// entry in accConfig_.customConfigMap, or when that config does not enable
// Thrift parameter negotiation.
//
// The entry is stored type-erased and is assumed to be a ThriftTlsConfig
// (static_cast; the concrete type is not checked at runtime). A null entry
// trips the assert in debug builds. In release builds, where the assert is
// compiled out, it is treated as "no config" instead of being dereferenced.
std::optional<ThriftParametersContext>
Cpp2Worker::getThriftParametersContext() {
  auto thriftConfigBase =
      folly::get_ptr(accConfig_.customConfigMap, "thrift_tls_config");
  if (!thriftConfigBase) {
    return std::nullopt;
  }
  auto* thriftConfig = static_cast<ThriftTlsConfig*>(thriftConfigBase->get());
  assert(thriftConfig);
  if (!thriftConfig || !thriftConfig->enableThriftParamsNegotiation) {
    return std::nullopt;
  }
  auto thriftParametersContext = ThriftParametersContext();
  // StopTLS is requested if either this acceptor's config or the
  // ThriftServer::enableStopTLS() setting asks for it.
  thriftParametersContext.setUseStopTLS(
      thriftConfig->enableStopTLS || **ThriftServer::enableStopTLS());
  return thriftParametersContext;
}
// Chooses the handshake helper for a peeked connection: the TLS helper from
// createSSLHelper() when shouldPerformSSL() says so, otherwise a plaintext
// wangle::UnencryptedAcceptorHandshakeHelper.
wangle::AcceptorHandshakeHelper::UniquePtr Cpp2Worker::getHelper(
    const std::vector<uint8_t>& bytes,
    const folly::SocketAddress& clientAddr,
    std::chrono::steady_clock::time_point acceptTime,
    wangle::TransportInfo& ti) {
  if (shouldPerformSSL(bytes, clientAddr)) {
    return createSSLHelper(bytes, clientAddr, acceptTime, ti);
  }
  return wangle::AcceptorHandshakeHelper::UniquePtr(
      new wangle::UnencryptedAcceptorHandshakeHelper());
}
void Cpp2Worker::requestStop() {
getEventBase()->runInEventBaseThreadAndWait([&] {
if (isStopping()) {
return;
}
cancelQueuedRequests();
stopping_.store(true, std::memory_order_relaxed);
if (activeRequests_ == 0) {
stopBaton_.post();
}
});
}
// Waits until stopBaton_ is posted or `deadline` passes. Returns true if it
// was posted in time; otherwise logs an error and returns false.
bool Cpp2Worker::waitForStop(std::chrono::steady_clock::time_point deadline) {
  const bool stopped = stopBaton_.try_wait_until(deadline);
  if (!stopped) {
    LOG(ERROR) << "Failed to join outstanding requests.";
  }
  return stopped;
}
// Must be called on this worker's EventBase thread (debug-checked). For each
// request in the registry's active list whose state machine is still active
// and accepts tryStopProcessing(), sends a queue-timeout response.
void Cpp2Worker::cancelQueuedRequests() {
  getEventBase()->dcheckIsInEventBaseThread();
  for (auto& activeStub : requestsRegistry_->getActive()) {
    auto& machine = activeStub.stateMachine_;
    const bool cancelled = machine.isActive() && machine.tryStopProcessing();
    if (cancelled) {
      activeStub.req_->sendQueueTimeoutResponse();
    }
  }
}
// Increments activeRequests_ and returns an ActiveRequestsGuard bound to this
// worker. Debug builds assert that a guard is only taken while stopping if
// some requests are already active.
Cpp2Worker::ActiveRequestsGuard Cpp2Worker::getActiveRequestsGuard() {
  DCHECK(!isStopping() || activeRequests_);
  activeRequests_ += 1;
  return ActiveRequestsGuard(this);
}
// Resolves `methodName` against the metadata held in methods_:
//  - MetadataNotImplemented: passed through as MetadataNotImplemented;
//  - MethodMetadataMap: MetadataFound for a known method, else
//    MetadataNotFound;
//  - WildcardMethodMetadataMap: MetadataFound with the method's own metadata
//    if it is in knownMethods, else with kWildcardMethodMetadata.
// Any other state is a fatal error.
Cpp2Worker::PerServiceMetadata::FindMethodResult
Cpp2Worker::PerServiceMetadata::findMethod(std::string_view methodName) const {
  // Looks methodName up in either map flavour; nullptr when it is absent.
  const auto lookup = [methodName](const auto& metadataMap) {
    const auto* entry = folly::get_ptr(metadataMap, methodName);
    DCHECK(entry == nullptr || entry->get());
    return entry;
  };

  if (std::holds_alternative<AsyncProcessorFactory::MetadataNotImplemented>(
          methods_)) {
    return MetadataNotImplemented{};
  }

  if (const auto* exactMap =
          std::get_if<AsyncProcessorFactory::MethodMetadataMap>(&methods_)) {
    const auto* entry = lookup(*exactMap);
    if (entry == nullptr) {
      return MetadataNotFound{};
    }
    return MetadataFound{**entry};
  }

  if (const auto* wildcardMap =
          std::get_if<AsyncProcessorFactory::WildcardMethodMetadataMap>(
              &methods_)) {
    const auto* entry = lookup(wildcardMap->knownMethods);
    if (entry == nullptr) {
      return MetadataFound{AsyncProcessorFactory::kWildcardMethodMetadata};
    }
    return MetadataFound{**entry};
  }

  LOG(FATAL) << "Invalid CreateMethodMetadataResult from service";
  folly::assume_unreachable();
}
// Returns the processor factory's base RequestContext for a resolved method,
// or nullptr when `findMethodResult` is not MetadataFound.
std::shared_ptr<folly::RequestContext>
Cpp2Worker::PerServiceMetadata::getBaseContextForRequest(
    const Cpp2Worker::PerServiceMetadata::FindMethodResult& findMethodResult)
    const {
  const auto* found =
      std::get_if<PerServiceMetadata::MetadataFound>(&findMethodResult);
  if (found == nullptr) {
    return nullptr;
  }
  return processorFactory_.getBaseContextForRequest(found->metadata);
}
} // namespace thrift
} // namespace apache