thrift/lib/cpp2/server/ThriftProcessor.cpp (149 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrift/lib/cpp2/server/ThriftProcessor.h>
#include <string>
#include <folly/Overload.h>
#include <fmt/core.h>
#include <glog/logging.h>
#include <thrift/lib/cpp/transport/THeader.h>
#include <thrift/lib/cpp2/async/AsyncProcessorHelper.h>
#include <thrift/lib/cpp2/async/ResponseChannel.h>
#include <thrift/lib/cpp2/server/Cpp2ConnContext.h>
#include <thrift/lib/cpp2/server/Cpp2Worker.h>
#include <thrift/lib/cpp2/server/ThriftServer.h>
#include <thrift/lib/cpp2/transport/core/ThriftRequest.h>
#include <thrift/lib/cpp2/util/Checksum.h>
namespace apache {
namespace thrift {
ThriftProcessor::ThriftProcessor(ThriftServer& server) : server_(server) {}
// Entry point for a single Thrift request arriving on `channel`.
//
// Steps, in order:
//  1. Lazily obtains the AsyncProcessor from the server's decorated processor
//     factory on first use and caches it in processor_.
//  2. Rejects requests whose metadata lacks protocol, name or kind
//     (UNSUPPORTED_CLIENT_TYPE) or whose optional crc32c does not match the
//     payload (CHECKSUM_MISMATCH). These errors are sent from a callback
//     scheduled on the channel's EventBase.
//  3. Looks up the method in the worker's per-service metadata and installs a
//     folly::RequestContext for the rest of the call. This is a child copy of
//     the method's base context when one exists, otherwise a fresh context.
//  4. Dispatches on the lookup result:
//     - MetadataNotImplemented: processSerializedCompressedRequest().
//     - MetadataNotFound: replies with an unknown-method error.
//     - MetadataFound: if the server has resource pools, selects a pool and
//       hands the request to it, replying with an error if the pool is
//       rejected or refuses the request. Otherwise calls
//       processSerializedCompressedRequestWithMetadata() with the server's
//       ThreadManager.
//
// Must be called on the worker's EventBase thread (DCHECKed below).
// payload, channel and connContext must be non-null.
void ThriftProcessor::onThriftRequest(
    RequestRpcMetadata&& metadata,
    std::unique_ptr<IOBuf> payload,
    std::shared_ptr<ThriftChannelIf> channel,
    std::unique_ptr<Cpp2ConnContext> connContext) noexcept {
  DCHECK(payload);
  DCHECK(channel);
  // connContext is dereferenced immediately below.
  DCHECK(connContext);
  auto& processorFactory = server_.getDecoratedProcessorFactory();
  if (processor_ == nullptr) {
    processor_ = processorFactory.getProcessor();
  }
  auto worker = connContext->getWorker();
  worker->getEventBase()->dcheckIsInEventBaseThread();
  // Both checks must run before `metadata` is moved into the ThriftRequest.
  bool invalidMetadata =
      !(metadata.protocol_ref() && metadata.name_ref() && metadata.kind_ref());
  bool invalidChecksum = metadata.crc32c_ref() &&
      *metadata.crc32c_ref() != apache::thrift::checksum::crc32c(*payload);
  auto request = std::make_unique<ThriftRequest>(
      server_, channel, std::move(metadata), std::move(connContext));
  auto* evb = channel->getEventBase();
  if (UNLIKELY(invalidMetadata)) {
    LOG(ERROR) << "Invalid metadata object";
    evb->runInEventBaseThread([request = std::move(request)]() {
      request->sendErrorWrapped(
          folly::make_exception_wrapper<TApplicationException>(
              TApplicationException::UNSUPPORTED_CLIENT_TYPE,
              "invalid metadata object"),
          "corrupted metadata");
    });
    return;
  }
  if (UNLIKELY(invalidChecksum)) {
    LOG(ERROR) << "Invalid checksum";
    evb->runInEventBaseThread([request = std::move(request)]() {
      request->sendErrorWrapped(
          folly::make_exception_wrapper<TApplicationException>(
              TApplicationException::CHECKSUM_MISMATCH, "checksum mismatch"),
          "corrupted request");
    });
    return;
  }
  const auto& serviceMetadata = worker->getMetadataForService(processorFactory);
  using PerServiceMetadata = Cpp2Worker::PerServiceMetadata;
  const PerServiceMetadata::FindMethodResult methodMetadataResult =
      serviceMetadata.findMethod(request->getMethodName());
  auto baseReqCtx =
      serviceMetadata.getBaseContextForRequest(methodMetadataResult);
  auto reqCtx = baseReqCtx ? folly::RequestContext::copyAsChild(*baseReqCtx)
                           : std::make_shared<folly::RequestContext>();
  // Keeps reqCtx installed as the current RequestContext until this function
  // returns.
  folly::RequestContextScopeGuard rctx(reqCtx);
  auto protoId = request->getProtoId();
  auto reqContext = request->getRequestContext();
  folly::variant_match(
      methodMetadataResult,
      [&](PerServiceMetadata::MetadataNotImplemented) {
        // The AsyncProcessorFactory does not implement createMethodMetadata
        // so we need to fallback to processSerializedCompressedRequest.
        processor_->processSerializedCompressedRequest(
            std::move(request),
            SerializedCompressedRequest(std::move(payload)),
            protoId,
            reqContext,
            evb,
            server_.getThreadManager().get());
      },
      [&](PerServiceMetadata::MetadataNotFound) {
        // methodName views into *request. Ownership of that same request is
        // passed to the callee, which presumably reads methodName before
        // releasing it. NOTE(review): verify against sendUnknownMethodError.
        std::string_view methodName = request->getMethodName();
        AsyncProcessorHelper::sendUnknownMethodError(
            std::move(request), methodName);
      },
      [&](const PerServiceMetadata::MetadataFound& found) {
        if (!server_.resourcePoolSet().empty()) {
          // We need to process this using request pools
          const ServiceRequestInfo* serviceRequestInfo{nullptr};
          if (auto requestInfo = processorFactory.getServiceRequestInfoMap()) {
            // Use find() instead of at(). at() throws std::out_of_range for a
            // missing key, and an exception escaping this noexcept function
            // calls std::terminate. A missing entry falls back to the same
            // nullptr used when no map is available at all.
            const auto& requestInfoMap = requestInfo->get();
            auto it = requestInfoMap.find(request->getMethodName());
            if (it != requestInfoMap.end()) {
              serviceRequestInfo = &it->second;
            }
          }
          ServerRequest serverRequest(
              std::move(request),
              SerializedCompressedRequest(std::move(payload)),
              evb,
              reqContext,
              protoId,
              folly::RequestContext::saveContext(),
              processor_.get(),
              &found.metadata,
              serviceRequestInfo);
          auto poolResult = AsyncProcessorHelper::selectResourcePool(
              serverRequest, found.metadata);
          if (auto* reject = std::get_if<ServerRequestRejection>(&poolResult)) {
            auto errorCode = kAppOverloadedErrorCode;
            if (reject->applicationException().getType() ==
                TApplicationException::UNKNOWN_METHOD) {
              errorCode = kMethodUnknownErrorCode;
            }
            serverRequest.request()->sendErrorWrapped(
                folly::exception_wrapper(
                    folly::in_place, std::move(*reject).applicationException()),
                errorCode);
            return;
          }
          // The rejection alternative returned above, so the variant must
          // hold the pool handle here.
          auto resourcePoolHandle =
              std::get_if<std::reference_wrapper<const ResourcePoolHandle>>(
                  &poolResult);
          DCHECK(resourcePoolHandle);
          DCHECK(
              server_.resourcePoolSet().hasResourcePool(*resourcePoolHandle));
          auto& resourcePool =
              server_.resourcePoolSet().resourcePool(*resourcePoolHandle);
          apache::thrift::detail::ServerRequestHelper::setExecutor(
              serverRequest, resourcePool.executor().value_or(nullptr));
          // NOTE(review): serverRequest is used again below after being passed
          // via std::move. This relies on accept() not consuming the request
          // when it returns a rejection. Confirm against ResourcePool::accept.
          auto result = resourcePool.accept(std::move(serverRequest));
          if (result) {
            auto errorCode = kQueueOverloadedErrorCode;
            serverRequest.request()->sendErrorWrapped(
                folly::exception_wrapper(
                    folly::in_place,
                    std::move(*result).applicationException()),
                errorCode);
            return;
          }
        } else {
          processor_->processSerializedCompressedRequestWithMetadata(
              std::move(request),
              SerializedCompressedRequest(std::move(payload)),
              found.metadata,
              protoId,
              reqContext,
              evb,
              server_.getThreadManager().get());
        }
      });
}
} // namespace thrift
} // namespace apache