thrift/lib/cpp2/async/AsyncProcessor.cpp (584 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrift/lib/cpp2/async/AsyncProcessor.h>
#include <folly/io/async/EventBaseAtomicNotificationQueue.h>
#include <thrift/lib/cpp2/async/ReplyInfo.h>
namespace apache {
namespace thrift {
// Out-of-line definition of the wildcard method-metadata object exposed by
// AsyncProcessorFactory.
/* static */ const AsyncProcessorFactory::WildcardMethodMetadata
    AsyncProcessorFactory::kWildcardMethodMetadata;
// Namespace-scope definition of the constexpr static data member. This is
// redundant (but harmless) in C++17, where constexpr static members are
// implicitly inline; NOTE(review): presumably kept for older toolchains.
constexpr std::chrono::seconds ServerInterface::BlockingThreadManager::kTimeout;
// Per-thread request parameters. For example, setEventBase() stores the
// current event base in requestParams_.eventBase_.
thread_local RequestParams ServerInterface::requestParams_;
// A task destroyed before it ran is treated as expired, so its request is
// still answered (expired() is a no-op once the request was consumed).
EventTask::~EventTask() {
  this->expired();
}
// Fails the pending request with a "task expired" error.
// failWith() moves req_ out, so any later call finds req_ empty and does
// nothing; the request is therefore expired at most once.
void EventTask::expired() {
  if (req_) {
    failWith(
        TApplicationException{"Task expired without processing"},
        kTaskExpiredErrorCode);
  }
}
// Sends `ex` (tagged with `exCode`) back for this task's request.
// The error is sent on eb_'s thread: inline when already running there,
// otherwise via runInEventBaseThread. Oneway requests get no response.
// req_ is moved out immediately, leaving it empty afterwards.
void EventTask::failWith(folly::exception_wrapper ex, std::string exCode) {
  auto sendError = [isOneway = oneway_,
                    request = std::move(req_),
                    error = std::move(ex),
                    code = std::move(exCode)]() mutable {
    if (!isOneway) {
      request->sendErrorWrapped(std::move(error), std::move(code));
    }
  };
  if (!eb_->inRunningEventBaseThread()) {
    eb_->runInEventBaseThread(std::move(sendError));
    return;
  }
  sendError();
}
// Forwards the interaction tile to this task's request context.
void EventTask::setTile(TilePtr&& tile) {
  auto&& interactionTile = std::move(tile);
  ctx_->setTile(std::move(interactionTile));
}
// Default implementation: aborts the process.
// Processor sets that support this lookup are expected to override it.
// The fatal message is spelled the same way as the other "Unimplemented ..."
// messages in this file.
std::pair<AsyncProcessor*, const AsyncProcessorFactory::MethodMetadata*>
AsyncProcessorSet::getRequestsProcessor(
    const AsyncProcessorFactory::MethodMetadata&) {
  LOG(FATAL) << "Unimplemented getRequestsProcessor called";
}
// Placeholder name reported by processors that do not override this.
char const* AsyncProcessor::getServiceName() {
  static constexpr char const* kUnsetServiceName = "NoServiceNameSet";
  return kUnsetServiceName;
}
// The base processor has no interaction support. Reaching this is reported
// via LOG(DFATAL): fatal in debug builds, logged as an error otherwise.
void AsyncProcessor::terminateInteraction(
    int64_t /* id */,
    Cpp2ConnContext& /* conn */,
    folly::EventBase& /* eb */) noexcept {
  LOG(DFATAL) << "This processor doesn't support interactions";
}
// The base processor does not support interactions, so there is nothing to
// destroy.
void AsyncProcessor::destroyAllInteractions(
    Cpp2ConnContext& /* conn */, folly::EventBase& /* eb */) noexcept {
  // Intentionally empty.
}
// Default handling of compressed requests.
// The payload is decompressed eagerly and the request is then forwarded,
// unchanged otherwise, to processSerializedRequest().
void AsyncProcessor::processSerializedCompressedRequest(
    ResponseChannelRequest::UniquePtr req,
    SerializedCompressedRequest&& serializedRequest,
    protocol::PROTOCOL_TYPES prot_type,
    Cpp2RequestContext* context,
    folly::EventBase* eb,
    concurrency::ThreadManager* tm) {
  auto uncompressed = std::move(serializedRequest).uncompress();
  processSerializedRequest(
      std::move(req), std::move(uncompressed), prot_type, context, eb, tm);
}
// Default implementation: aborts the process.
// As the fatal message explains, this is reached when the service's
// createMethodMetadata opted into MethodMetadata-based resolution without
// overriding this method.
void AsyncProcessor::processSerializedCompressedRequestWithMetadata(
    ResponseChannelRequest::UniquePtr /* req */,
    SerializedCompressedRequest&& /* serializedRequest */,
    const MethodMetadata& /* methodMetadata */,
    protocol::PROTOCOL_TYPES /* protType */,
    Cpp2RequestContext* /* context */,
    folly::EventBase* /* eb */,
    concurrency::ThreadManager* /* tm */) {
  LOG(FATAL)
      << "processSerializedCompressedRequestWithMetadata was called because "
         "AsyncProcessorFactory::createMethodMetadata (from the provided service) "
         "opted in to use the MethodMetadata-based method resolution API. "
         "Therefore, this method must be overridden alongside processSerializedRequest.";
}
// Default for processors that do not implement the ServerRequest-based
// execution path. Aborts the process if called.
void AsyncProcessor::executeRequest(
    ServerRequest&& /* request */,
    const AsyncProcessorFactory::MethodMetadata& /* methodMetadata */) {
  LOG(FATAL) << "Unimplemented executeRequest called";
}
// Registers a new interaction (tile) with id `id` on the request's
// connection. Must be called on `eb`'s thread.
//
// There are three paths:
//  * No thread manager and an old-style constructor: the interaction is
//    constructed inline via createInteractionImpl(name).
//  * Thread manager and an old-style constructor: a TilePromise is registered
//    now. The constructor then runs on `tm`, and the promise is fulfilled (or
//    failed) back on `eb`.
//  * Factory function: a TilePromise is registered, and the handler is left
//    to fulfill it.
//
// Returns false when conn.addTile() rejects the id; the caller treats that as
// a duplicate id. Returns true otherwise, including when the inline
// constructor failed (the error is sent to `req` and no tile is added).
bool GeneratedAsyncProcessor::createInteraction(
    const ResponseChannelRequest::UniquePtr& req,
    int64_t id,
    std::string&& name,
    Cpp2RequestContext& ctx,
    concurrency::ThreadManager* tm,
    folly::EventBase& eb,
    ServerInterface* si,
    bool isFactoryFunction) {
  eb.dcheckIsInEventBaseThread();
  // Converts a null tile from the constructor into an exception, so null and
  // throwing constructors share the same failure path.
  auto nullthrows = [](std::unique_ptr<Tile> tile) {
    if (!tile) {
      DLOG(FATAL) << "Nullptr returned from interaction constructor";
      throw std::runtime_error("Nullptr returned from interaction constructor");
    }
    return tile;
  };
  auto& conn = *ctx.getConnectionContext();
  // In the eb model with old-style constructor we create the interaction
  // inline.
  if (!tm && !isFactoryFunction) {
    si->setEventBase(&eb);
    si->setRequestContext(&ctx);
    auto tile = folly::makeTryWith(
        [&] { return nullthrows(createInteractionImpl(name)); });
    if (tile.hasException()) {
      req->sendErrorWrapped(
          folly::make_exception_wrapper<TApplicationException>(
              "Interaction constructor failed with " +
              tile.exception().what().toStdString()),
          kInteractionConstructorErrorErrorCode);
      return true; // Not a duplicate; caller will see missing tile.
    }
    return conn.addTile(id, {tile->release(), &eb});
  }
  // Otherwise we use a promise.
  // NOTE(review): the raw pointer is handed to the TilePtr built from
  // {promisePtr, &eb} in the addTile() call below; confirm that TilePtr
  // releases it when addTile() rejects a duplicate id.
  auto promisePtr =
      new TilePromise(isFactoryFunction); // freed by RefGuard on next line
  if (!conn.addTile(id, {promisePtr, &eb})) {
    return false;
  }
  // Old-style constructor + tm : schedule constructor and return
  if (!isFactoryFunction) {
    // [=] also implicitly captures `this` (createInteractionImpl is a
    // member). NOTE(review): this assumes the processor, `si`, `ctx`, `conn`
    // and `eb` all outlive the scheduled task; verify.
    tm->add([=, &eb, &ctx, name = std::move(name), &conn] {
      si->setEventBase(&eb);
      si->setThreadManager(tm);
      si->setRequestContext(&ctx);
      std::exception_ptr ex;
      try {
        auto tilePtr = nullthrows(createInteractionImpl(name));
        // Back on the event base: fulfill the promise with the real tile,
        // then replace the promise registered under `id` with that tile.
        eb.add([=, &conn, &eb, t = std::move(tilePtr)]() mutable {
          TilePtr tile{t.release(), &eb};
          promisePtr->fulfill(*tile, tm, eb);
          conn.tryReplaceTile(id, std::move(tile));
        });
        return;
      } catch (...) {
        ex = std::current_exception();
      }
      DCHECK(ex);
      // Constructor threw (or returned null): fail the promise on the event
      // base.
      eb.add([promisePtr, ex = std::move(ex)]() {
        promisePtr->failWith(
            folly::make_exception_wrapper<TApplicationException>(
                folly::to<std::string>(
                    "Interaction constructor failed with ",
                    folly::exceptionStr(ex))),
            kInteractionConstructorErrorErrorCode);
      });
    });
    return true;
  }
  // Factory function: the handler method will fulfill the promise
  return true;
}
// Default: no interaction types are known, so nothing is created.
// createInteraction() treats a null result as a constructor failure.
std::unique_ptr<Tile> GeneratedAsyncProcessor::createInteractionImpl(
    const std::string& /* name */) {
  return {};
}
// Removes interaction `id` from `conn` and hands it to
// Tile::__fbthrift_onTermination. An unknown id is ignored.
// Must be called on `eb`'s thread.
void GeneratedAsyncProcessor::terminateInteraction(
    int64_t id, Cpp2ConnContext& conn, folly::EventBase& eb) noexcept {
  eb.dcheckIsInEventBaseThread();
  auto removed = conn.removeTile(id);
  if (!removed) {
    return;
  }
  Tile::__fbthrift_onTermination(std::move(removed), eb);
}
// Tears down every interaction still registered on `conn`.
// Each removed tile goes through Tile::__fbthrift_onTermination, the same
// path terminateInteraction() uses, so connection teardown and explicit
// termination release interactions the same way. (Previously the removed
// tile was simply dropped, which skipped onTermination.)
// Must be called on `eb`'s thread.
void GeneratedAsyncProcessor::destroyAllInteractions(
    Cpp2ConnContext& conn, folly::EventBase& eb) noexcept {
  eb.dcheckIsInEventBaseThread();
  if (conn.tiles_.empty()) {
    return;
  }
  // Snapshot the ids first: removeTile() mutates conn.tiles_, so tiles must
  // not be removed while that map is being iterated.
  std::vector<int64_t> ids;
  ids.reserve(conn.tiles_.size());
  for (auto& [id, tile] : conn.tiles_) {
    ids.push_back(id);
  }
  for (auto id : ids) {
    if (auto tile = conn.removeTile(id)) {
      Tile::__fbthrift_onTermination(std::move(tile), eb);
    }
  }
}
// Checks that the request's RPC kind is compatible with the method's `kind`.
//  * Oneway method: accepts oneway requests. It also accepts request-response
//    requests, which are immediately answered with an empty reply.
//  * Request-response method: accepts oneway and request-response requests.
//  * Any other kind: requires an exact match.
// On mismatch, returns false and sends a "Function kind mismatch" error,
// unless the request itself is oneway.
bool GeneratedAsyncProcessor::validateRpcKind(
    const ResponseChannelRequest::UniquePtr& req, RpcKind kind) {
  const RpcKind requestKind = req->rpcKind();
  bool compatible = false;
  if (kind == RpcKind::SINGLE_REQUEST_NO_RESPONSE) {
    if (requestKind == RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE) {
      req->sendReply(ResponsePayload{});
      compatible = true;
    } else {
      compatible = requestKind == RpcKind::SINGLE_REQUEST_NO_RESPONSE;
    }
  } else if (kind == RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE) {
    compatible = requestKind == RpcKind::SINGLE_REQUEST_NO_RESPONSE ||
        requestKind == RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE;
  } else {
    compatible = kind == requestKind;
  }
  if (compatible) {
    return true;
  }
  if (requestKind != RpcKind::SINGLE_REQUEST_NO_RESPONSE) {
    req->sendErrorWrapped(
        folly::make_exception_wrapper<TApplicationException>(
            TApplicationException::TApplicationExceptionType::UNKNOWN_METHOD,
            "Function kind mismatch"),
        kRequestTypeDoesntMatchServiceFunctionType);
  }
  return false;
}
// Common per-request validation performed before dispatching to a handler.
//
// Validates, in order:
//  1. The request's RPC kind against the method's `kind` (validateRpcKind).
//  2. The request's interaction metadata against the method's `interaction`
//     name. A request that creates an interaction triggers
//     createInteraction() here.
// When there is no thread manager (`tm == nullptr`), the interaction tile is
// also attached to `ctx`.
//
// Returns true if processing may continue. On false, an error has been sent
// to `req`; the exception is validateRpcKind, which does not reply to oneway
// requests.
bool GeneratedAsyncProcessor::setUpRequestProcessing(
    const ResponseChannelRequest::UniquePtr& req,
    Cpp2RequestContext* ctx,
    folly::EventBase* eb,
    concurrency::ThreadManager* tm,
    RpcKind kind,
    ServerInterface* si,
    folly::StringPiece interaction,
    bool isInteractionFactoryFunction) {
  if (!validateRpcKind(req, kind)) {
    return false;
  }
  // Assigned on every branch below before it is read.
  bool interactionMetadataValid;
  if (auto interactionId = ctx->getInteractionId()) {
    if (auto interactionCreate = ctx->getInteractionCreate()) {
      // The request creates an interaction; its name must match the method's.
      if (*interactionCreate->interactionName_ref() != interaction) {
        interactionMetadataValid = false;
      } else if (!createInteraction(
                     req,
                     interactionId,
                     std::move(*interactionCreate->interactionName_ref()).str(),
                     *ctx,
                     tm,
                     *eb,
                     si,
                     isInteractionFactoryFunction)) {
        // Duplicate id is a contract violation so close the connection.
        // Terminate this interaction first so queued requests can't use it
        // (which could result in UB).
        terminateInteraction(interactionId, *ctx->getConnectionContext(), *eb);
        req->sendErrorWrapped(
            TApplicationException(
                "Attempting to create interaction with duplicate id. Failing all requests in that interaction."),
            kConnectionClosingErrorCode);
        return false;
      } else {
        interactionMetadataValid = true;
      }
    } else {
      // The request continues an existing interaction, which is only valid
      // for a method that belongs to an interaction.
      interactionMetadataValid = !interaction.empty();
    }
    if (interactionMetadataValid && !tm) {
      try {
        // This is otherwise done while constructing InteractionEventTask.
        auto& tile = ctx->getConnectionContext()->getTile(interactionId);
        ctx->setTile({&tile, eb});
      } catch (const std::out_of_range&) {
        req->sendErrorWrapped(
            TApplicationException(
                "Invalid interaction id " + std::to_string(interactionId)),
            kInteractionIdUnknownErrorCode);
        return false;
      }
    }
  } else {
    // The request has no interaction id, which is only valid for a method
    // that is not part of an interaction.
    interactionMetadataValid = interaction.empty();
  }
  if (!interactionMetadataValid) {
    req->sendErrorWrapped(
        folly::make_exception_wrapper<TApplicationException>(
            TApplicationException::TApplicationExceptionType::UNKNOWN_METHOD,
            "Interaction and method do not match"),
        kMethodUnknownErrorCode);
    return false;
  }
  return true;
}
// Returns the priority set on the call.
// Falls back to `prio` when the call priority equals N_PRIORITIES, which is
// treated here as "not set".
concurrency::PRIORITY ServerInterface::getRequestPriority(
    Cpp2RequestContext* ctx, concurrency::PRIORITY prio) {
  const concurrency::PRIORITY requested = ctx->getCallPriority();
  if (requested == concurrency::N_PRIORITIES) {
    return prio;
  }
  return requested;
}
// Records `eb` as the current request's event base in two places: the
// folly::RequestEventBase slot and this thread's RequestParams.
void ServerInterface::setEventBase(folly::EventBase* eb) {
  folly::RequestEventBase::set(eb);
  auto& params = requestParams_;
  params.eventBase_ = eb;
}
// Schedules `f`.
// Uses the thread manager when one is held, otherwise the executor.
// Any exception raised while scheduling aborts the process via LOG(FATAL).
void ServerInterface::BlockingThreadManager::add(folly::Func f) {
  try {
    if (!threadManagerKa_) {
      executorKa_->add(std::move(f));
      return;
    }
    std::shared_ptr<concurrency::Runnable> task =
        concurrency::FunctionRunner::create(std::move(f));
    threadManagerKa_->add(
        std::move(task),
        std::chrono::milliseconds(kTimeout).count() /* deprecated */,
        0,
        false);
  } catch (...) {
    LOG(FATAL) << "Failed to schedule a task within timeout: "
               << folly::exceptionStr(std::current_exception());
  }
}
bool ServerInterface::BlockingThreadManager::keepAliveAcquire() noexcept {
auto keepAliveCount = keepAliveCount_.fetch_add(1, std::memory_order_relaxed);
// We should never increment from 0
DCHECK(keepAliveCount > 0);
return true;
}
void ServerInterface::BlockingThreadManager::keepAliveRelease() noexcept {
auto keepAliveCount = keepAliveCount_.fetch_sub(1, std::memory_order_acq_rel);
DCHECK(keepAliveCount >= 1);
if (keepAliveCount == 1) {
delete this;
}
}
// Returns a keep-alive for internal work.
// With a thread manager, it is scoped to the request's execution scope and
// tagged Source::INTERNAL. Otherwise it is a keep-alive on the handler
// executor.
folly::Executor::KeepAlive<> ServerInterface::getInternalKeepAlive() {
  auto threadManager = getThreadManager();
  if (!threadManager) {
    return folly::Executor::getKeepAliveToken(getHandlerExecutor());
  }
  return threadManager->getKeepAlive(
      getRequestContext()->getRequestExecutionScope(),
      apache::thrift::concurrency::ThreadManager::Source::INTERNAL);
}
// Returns a keep-alive for internal work.
// With a thread manager, it is scoped to the request's execution scope and
// tagged Source::INTERNAL. Otherwise it is a keep-alive on the handler
// executor.
folly::Executor::KeepAlive<> HandlerCallbackBase::getInternalKeepAlive() {
  auto threadManager = getThreadManager();
  if (!threadManager) {
    return folly::Executor::getKeepAliveToken(getHandlerExecutor());
  }
  return threadManager->getKeepAlive(
      getRequestContext()->getRequestExecutionScope(),
      apache::thrift::concurrency::ThreadManager::Source::INTERNAL);
}
// When a callback is destroyed while still holding its request:
//  * If the request is still active and an exception callback exists, the
//    request is failed with "HandlerCallback not completed".
//  * Otherwise the request, together with the interaction, is released so
//    that it is destroyed on the event base.
HandlerCallbackBase::~HandlerCallbackBase() {
  maybeNotifyComplete();
  if (!req_) {
    return;
  }
  const bool completeWithError = req_->isActive() && ewp_;
  if (completeWithError) {
    exception(TApplicationException(
        TApplicationException::INTERNAL_ERROR,
        "apache::thrift::HandlerCallback not completed"));
    return;
  }
  // req must be deleted in the eb
  assert(eb_ != nullptr);
  releaseRequest(std::move(req_), eb_, std::move(interaction_));
}
// Ensures `request` and `interaction` are destroyed on `eb`'s thread.
// If already running there, they are destroyed as the parameters go out of
// scope. Otherwise they are moved into an empty task posted to `eb`.
void HandlerCallbackBase::releaseRequest(
    ResponseChannelRequest::UniquePtr request,
    folly::EventBase* eb,
    TilePtr interaction) {
  DCHECK(request);
  DCHECK(eb != nullptr);
  if (eb->inRunningEventBaseThread()) {
    return;
  }
  eb->runInEventBaseThread(
      [req = std::move(request), interaction = std::move(interaction)] {});
}
// Event base the request belongs to; asserted non-null in debug builds.
folly::EventBase* HandlerCallbackBase::getEventBase() {
  folly::EventBase* const evb = eb_;
  assert(evb != nullptr);
  return evb;
}
// Thread manager this callback was created with (may be null).
concurrency::ThreadManager* HandlerCallbackBase::getThreadManager() {
  concurrency::ThreadManager* const threadManager = tm_;
  return threadManager;
}
// Returns the explicitly supplied executor when present; otherwise falls
// back to the thread manager.
folly::Executor* HandlerCallbackBase::getHandlerExecutor() {
  if (executor_ != nullptr) {
    return executor_;
  }
  return tm_;
}
// Checksum calculation is disabled in this build.
// If a reply checksum is requested for a non-empty payload, an error is
// logged. The result is always empty.
folly::Optional<uint32_t> HandlerCallbackBase::checksumIfNeeded(
    LegacySerializedResponse& response) {
  const bool checksumRequested = req_->isReplyChecksumNeeded() &&
      response.buffer && !response.buffer->empty();
  if (checksumRequested) {
    LOG(ERROR) << "Checksum calculation disabled";
  }
  return folly::Optional<uint32_t>{};
}
// Checksum calculation is disabled in this build.
// If a reply checksum is requested for a non-empty payload, an error is
// logged. The result is always empty.
folly::Optional<uint32_t> HandlerCallbackBase::checksumIfNeeded(
    SerializedResponse& response) {
  const bool checksumRequested = req_->isReplyChecksumNeeded() &&
      response.buffer && !response.buffer->empty();
  if (checksumRequested) {
    LOG(ERROR) << "Checksum calculation disabled";
  }
  return folly::Optional<uint32_t>{};
}
// Applies the request header's write transforms (e.g. compression) to
// `payload` on the calling thread, the same thread that serialized it.
// std::move is required on return: `payload` is an rvalue-reference
// parameter, which C++17 does not implicitly move from.
ResponsePayload HandlerCallbackBase::transform(ResponsePayload&& payload) {
  auto&& header = reqCtx_->getHeader();
  payload.transform(header->getWriteTransforms());
  return std::move(payload);
}
// Reports `ew` to the client through the exception callback, on the event
// base thread. If the request has already been consumed, the error is only
// logged.
void HandlerCallbackBase::doExceptionWrapped(folly::exception_wrapper ew) {
  if (req_ == nullptr) {
    LOG(ERROR) << ew.what();
    return;
  }
  callExceptionInEventBaseThread(ewp_, ew);
}
// Builds the T_REPLY payload for a single response and delivers it.
// Delivery is immediate when on the event base thread; otherwise the reply
// is queued for that thread. The checksum is taken before extractPayload()
// consumes the response.
void HandlerCallbackBase::sendReply(SerializedResponse response) {
  folly::Optional<uint32_t> crc32c = checksumIfNeeded(response);
  auto extracted = std::move(response).extractPayload(
      req_->includeEnvelope(),
      reqCtx_->getHeader()->getProtocolId(),
      protoSeqId_,
      MessageType::T_REPLY,
      reqCtx_->getMethodName());
  auto payload = transform(std::move(extracted));
  folly::EventBase& evb = *getEventBase();
  if (!evb.isInEventBaseThread()) {
    putMessageInReplyQueue(
        std::in_place_type_t<QueueReplyInfo>(),
        std::move(req_),
        std::move(payload),
        crc32c);
    return;
  }
  QueueReplyInfo(std::move(req_), std::move(payload), crc32c)(evb);
}
// Builds the initial T_REPLY payload for a server stream and delivers it,
// together with the stream, which also takes over the interaction.
// Delivery is immediate when on the event base thread; otherwise the reply
// is queued for that thread.
void HandlerCallbackBase::sendReply(
    ResponseAndServerStreamFactory&& responseAndStream) {
  folly::Optional<uint32_t> crc32c =
      checksumIfNeeded(responseAndStream.response);
  auto extracted = std::move(responseAndStream.response)
                       .extractPayload(
                           req_->includeEnvelope(),
                           reqCtx_->getHeader()->getProtocolId(),
                           protoSeqId_,
                           MessageType::T_REPLY,
                           reqCtx_->getMethodName());
  auto payload = transform(std::move(extracted));
  auto& stream = responseAndStream.stream;
  stream.setInteraction(std::move(interaction_));
  folly::EventBase& evb = *getEventBase();
  if (!evb.isInEventBaseThread()) {
    putMessageInReplyQueue(
        std::in_place_type_t<StreamReplyInfo>(),
        std::move(req_),
        std::move(stream),
        std::move(payload),
        crc32c);
    return;
  }
  StreamReplyInfo(
      std::move(req_), std::move(stream), std::move(payload), crc32c)(evb);
}
// Builds the initial T_REPLY payload for a sink and delivers it, together
// with the sink consumer, which also takes over the interaction.
// Delivery is immediate when on the event base thread; otherwise the reply
// is queued for that thread. Without coroutine support this terminates the
// process.
void HandlerCallbackBase::sendReply(
    FOLLY_MAYBE_UNUSED std::pair<
        SerializedResponse,
        apache::thrift::detail::SinkConsumerImpl>&& responseAndSinkConsumer) {
#if FOLLY_HAS_COROUTINES
  folly::Optional<uint32_t> crc32c =
      checksumIfNeeded(responseAndSinkConsumer.first);
  auto extracted = std::move(responseAndSinkConsumer.first)
                       .extractPayload(
                           req_->includeEnvelope(),
                           reqCtx_->getHeader()->getProtocolId(),
                           protoSeqId_,
                           MessageType::T_REPLY,
                           reqCtx_->getMethodName());
  auto payload = transform(std::move(extracted));
  auto& sinkConsumer = responseAndSinkConsumer.second;
  sinkConsumer.interaction = std::move(interaction_);
  folly::EventBase& evb = *getEventBase();
  if (!evb.isInEventBaseThread()) {
    putMessageInReplyQueue(
        std::in_place_type_t<SinkConsumerReplyInfo>(),
        std::move(req_),
        std::move(sinkConsumer),
        std::move(payload),
        crc32c);
    return;
  }
  SinkConsumerReplyInfo(
      std::move(req_), std::move(sinkConsumer), std::move(payload), crc32c)(
      evb);
#else
  std::terminate();
#endif
}
bool HandlerCallbackBase::fulfillTilePromise(std::unique_ptr<Tile> ptr) {
if (!ptr) {
DLOG(FATAL) << "Nullptr interaction yielded from handler";
exception(TApplicationException(
TApplicationException::MISSING_RESULT,
"Nullptr interaction yielded from handler"));
return false;
}
auto fn = [ctx = reqCtx_,
interaction = std::move(interaction_),
ptr = std::move(ptr),
tm = tm_,
eb = eb_]() mutable {
TilePtr tile{ptr.release(), eb};
DCHECK(dynamic_cast<TilePromise*>(interaction.get()));
static_cast<TilePromise&>(*interaction).fulfill(*tile, tm, *eb);
ctx->getConnectionContext()->tryReplaceTile(
ctx->getInteractionId(), std::move(tile));
};
eb_->runImmediatelyOrRunInEventBaseThread(std::move(fn));
return true;
}
void HandlerCallbackBase::breakTilePromise() {
auto fn = [interaction = std::move(interaction_)]() mutable {
DCHECK(dynamic_cast<TilePromise*>(interaction.get()));
static_cast<TilePromise&>(*interaction)
.failWith(
folly::make_exception_wrapper<TApplicationException>(
"Interaction constructor failed"),
kInteractionConstructorErrorErrorCode);
};
eb_->runImmediatelyOrRunInEventBaseThread(std::move(fn));
}
// Constructs a void-returning callback bound to a thread manager.
// `cp` serializes the reply.
HandlerCallback<void>::HandlerCallback(
    ResponseChannelRequest::UniquePtr req,
    std::unique_ptr<ContextStack> ctx,
    cob_ptr cp,
    exnw_ptr ewp,
    int32_t protoSeqId,
    folly::EventBase* eb,
    concurrency::ThreadManager* tm,
    Cpp2RequestContext* reqCtx,
    TilePtr&& interaction)
    : HandlerCallbackBase(
          std::move(req),
          std::move(ctx),
          ewp,
          eb,
          tm,
          reqCtx,
          std::move(interaction)),
      cp_(cp) {
  // protoSeqId_ is a base-class member, so it cannot appear in this
  // constructor's member-initializer list.
  protoSeqId_ = protoSeqId;
}
// Constructs a void-returning callback bound to an executor.
// The request pile and concurrency controller, with their user data, are
// passed through to HandlerCallbackBase. `cp` serializes the reply.
HandlerCallback<void>::HandlerCallback(
    ResponseChannelRequest::UniquePtr req,
    std::unique_ptr<ContextStack> ctx,
    cob_ptr cp,
    exnw_ptr ewp,
    int32_t protoSeqId,
    folly::EventBase* eb,
    folly::Executor* executor,
    Cpp2RequestContext* reqCtx,
    RequestPileInterface* notifyRequestPile,
    RequestPileInterface::UserData notifyRequestPileUserData,
    ConcurrencyControllerInterface* notifyConcurrencyController,
    ConcurrencyControllerInterface::UserData notifyConcurrencyControllerData,
    TilePtr&& interaction)
    : HandlerCallbackBase(
          std::move(req),
          std::move(ctx),
          ewp,
          eb,
          executor,
          reqCtx,
          notifyRequestPile,
          notifyRequestPileUserData,
          notifyConcurrencyController,
          notifyConcurrencyControllerData,
          std::move(interaction)),
      cp_(cp) {
  // protoSeqId_ is a base-class member, so it cannot appear in this
  // constructor's member-initializer list.
  protoSeqId_ = protoSeqId;
}
// Completes the call from a Try: a stored exception is reported to the
// client, otherwise an empty success reply is sent.
void HandlerCallback<void>::complete(folly::Try<folly::Unit>&& r) {
  maybeNotifyComplete();
  if (!r.hasException()) {
    done();
    return;
  }
  exception(std::move(r.exception()));
}
// Serializes the reply through `cp_` while the context stack is still
// alive, releases the context stack, then sends the serialized reply.
void HandlerCallback<void>::doDone() {
  assert(cp_ != nullptr);
  auto serialized = cp_(this->ctx_.get());
  this->ctx_.reset();
  sendReply(std::move(serialized));
}
} // namespace thrift
} // namespace apache