thrift/lib/cpp2/transport/rocket/server/RocketThriftRequests.cpp (679 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrift/lib/cpp2/transport/rocket/server/RocketThriftRequests.h>
#include <functional>
#include <memory>
#include <utility>
#include <folly/ExceptionWrapper.h>
#include <folly/Function.h>
#include <folly/io/IOBuf.h>
#include <folly/io/IOBufQueue.h>
#include <thrift/lib/cpp/protocol/TBase64Utils.h>
#include <thrift/lib/cpp2/SerializationSwitch.h>
#include <thrift/lib/cpp2/async/ServerSinkBridge.h>
#include <thrift/lib/cpp2/async/StreamCallbacks.h>
#include <thrift/lib/cpp2/protocol/CompactProtocol.h>
#include <thrift/lib/cpp2/server/LoggingEvent.h>
#include <thrift/lib/cpp2/transport/core/RpcMetadataUtil.h>
#include <thrift/lib/cpp2/transport/rocket/PayloadUtils.h>
#include <thrift/lib/cpp2/transport/rocket/framing/Flags.h>
#include <thrift/lib/cpp2/transport/rocket/server/RocketServerConnection.h>
#include <thrift/lib/cpp2/transport/rocket/server/RocketSinkClientCallback.h>
#include <thrift/lib/cpp2/transport/rocket/server/RocketStreamClientCallback.h>
#include <thrift/lib/thrift/gen-cpp2/RpcMetadata_constants.h>
#include <thrift/lib/thrift/gen-cpp2/RpcMetadata_types.h>
namespace apache {
namespace thrift {
namespace rocket {
namespace {
/**
 * Builds a RocketException whose payload is a compact-serialized
 * ResponseRpcError describing `errorCode` / `message`.
 *
 * The error code is first bucketed into a ResponseRpcErrorCategory; that
 * category then selects the rocket-level ErrorCode of the returned exception
 * (INVALID for invalid requests, REJECTED for load shedding and shutdown,
 * CANCELED for everything else). If `metadata` carries a load value it is
 * copied onto the error.
 */
RocketException makeResponseRpcError(
    ResponseRpcErrorCode errorCode,
    folly::StringPiece message,
    const ResponseRpcMetadata& metadata) {
  // Any code not listed explicitly below is reported as an internal error.
  ResponseRpcErrorCategory errorCategory =
      ResponseRpcErrorCategory::INTERNAL_ERROR;
  switch (errorCode) {
    case ResponseRpcErrorCode::REQUEST_PARSING_FAILURE:
    case ResponseRpcErrorCode::WRONG_RPC_KIND:
    case ResponseRpcErrorCode::UNKNOWN_METHOD:
    case ResponseRpcErrorCode::CHECKSUM_MISMATCH:
    case ResponseRpcErrorCode::UNKNOWN_INTERACTION_ID:
    case ResponseRpcErrorCode::UNIMPLEMENTED_METHOD:
      errorCategory = ResponseRpcErrorCategory::INVALID_REQUEST;
      break;
    case ResponseRpcErrorCode::OVERLOAD:
    case ResponseRpcErrorCode::QUEUE_OVERLOADED:
    case ResponseRpcErrorCode::QUEUE_TIMEOUT:
    case ResponseRpcErrorCode::APP_OVERLOAD:
      errorCategory = ResponseRpcErrorCategory::LOADSHEDDING;
      break;
    case ResponseRpcErrorCode::SHUTDOWN:
      errorCategory = ResponseRpcErrorCategory::SHUTDOWN;
      break;
    default:
      break;
  }

  // Translate the category into the rocket frame-level error code.
  rocket::ErrorCode frameErrorCode = rocket::ErrorCode::CANCELED;
  if (errorCategory == ResponseRpcErrorCategory::INVALID_REQUEST) {
    frameErrorCode = rocket::ErrorCode::INVALID;
  } else if (
      errorCategory == ResponseRpcErrorCategory::LOADSHEDDING ||
      errorCategory == ResponseRpcErrorCategory::SHUTDOWN) {
    frameErrorCode = rocket::ErrorCode::REJECTED;
  }

  ResponseRpcError rpcError;
  rpcError.name_utf8_ref() =
      apache::thrift::TEnumTraits<ResponseRpcErrorCode>::findName(errorCode);
  rpcError.what_utf8_ref() = message.str();
  rpcError.code_ref() = errorCode;
  rpcError.category_ref() = errorCategory;
  if (auto reportedLoad = metadata.load_ref()) {
    rpcError.load_ref() = *reportedLoad;
  }

  return RocketException(frameErrorCode, packCompact(rpcError));
}
/**
 * Renames proxied-exception headers in `metadata`'s otherMetadata map to
 * their regular names and marks the response as carrying a proxied payload.
 *
 *   "puex"                   -> "uex"
 *   "puexw"                  -> "uexw"  (only looked at when "puex" existed)
 *   "pex"                    -> "ex"
 *   "servicerouter:sr_error" -> "servicerouter:sr_internal_error"
 *
 * The value is moved with insert(); assuming standard map semantics an entry
 * already stored under the destination key is therefore not overwritten. The
 * source key is erased in either case. Does nothing when otherMetadata is
 * unset.
 */
void preprocessProxiedExceptionHeaders(
    ResponseRpcMetadata& metadata, int32_t version) {
  DCHECK_GE(version, 4);

  auto headersRef = metadata.otherMetadata_ref();
  if (!headersRef) {
    return;
  }
  auto& headers = *headersRef;

  // Moves the value stored under `from` to `to` and drops `from`. When
  // `markProxied` is set, the proxied-payload marker is written before the
  // header is moved. Returns whether `from` was present.
  const auto moveHeader =
      [&](const char* from, const char* to, bool markProxied) {
        auto* value = folly::get_ptr(headers, from);
        if (value == nullptr) {
          return false;
        }
        if (markProxied) {
          metadata.proxiedPayloadMetadata_ref() = ProxiedPayloadMetadata();
        }
        headers.insert({to, std::move(*value)});
        headers.erase(from);
        return true;
      };

  if (moveHeader("puex", "uex", true)) {
    // The "what" header only accompanies a proxied user exception.
    moveHeader("puexw", "uexw", false);
  }
  moveHeader("pex", "ex", true);
  moveHeader(
      "servicerouter:sr_error", "servicerouter:sr_internal_error", true);
}
// Inspects the serialized message envelope at the front of `payload` using
// `ProtocolReader` and records the outcome in `metadata.payloadMetadata`.
//
//  - T_REPLY: the message-begin header is stripped from `payload`. A first
//    field id of 0 is recorded as a normal response; any other field id is
//    recorded as a declared exception, with name/what/classification taken
//    from the headers in `metadata.otherMetadata`.
//  - T_EXCEPTION: the TApplicationException body is deserialized and mapped
//    to a proxy exception, a ResponseRpcError (returned as an error), or an
//    app-unknown exception, depending on the headers.
//  - any other message type, or any exception thrown while parsing, yields a
//    ResponseRpcErrorCode::UNKNOWN error.
//
// Returns an empty exception_wrapper on success; otherwise a wrapper built
// from the RocketException produced by makeResponseRpcError(). `payload` and
// `metadata` are modified in place.
template <typename ProtocolReader>
FOLLY_NODISCARD folly::exception_wrapper processFirstResponseHelper(
ResponseRpcMetadata& metadata,
std::unique_ptr<folly::IOBuf>& payload,
int32_t version) noexcept {
try {
std::string methodNameIgnore;
MessageType mtype;
int32_t seqIdIgnore;
ProtocolReader reader;
reader.setInput(payload.get());
reader.readMessageBegin(methodNameIgnore, mtype, seqIdIgnore);
switch (mtype) {
case MessageType::T_REPLY: {
// Number of bytes consumed by the message-begin header; only this prefix is
// removed from the payload below.
auto prefixSize = reader.getCursorPosition();
protocol::TType ftype;
int16_t fid;
// Read ahead to learn the id of the first field of the result struct. These
// bytes are not counted in prefixSize and therefore stay in the payload.
reader.readStructBegin(methodNameIgnore);
reader.readFieldBegin(methodNameIgnore, ftype, fid);
// Drop leading IOBufs of the chain that lie entirely inside the prefix, then
// trim what remains of the prefix from the new head buffer.
while (payload->length() < prefixSize) {
prefixSize -= payload->length();
payload = payload->pop();
}
payload->trimStart(prefixSize);
PayloadMetadata payloadMetadata;
if (fid == 0) {
// Field id 0: recorded as a regular response.
payloadMetadata.set_responseMetadata(PayloadResponseMetadata());
} else {
// Any other field id: recorded as a declared exception.
preprocessProxiedExceptionHeaders(metadata, version);
PayloadExceptionMetadataBase exceptionMetadataBase;
PayloadDeclaredExceptionMetadata declaredExceptionMetadata;
if (auto otherMetadataRef = metadata.otherMetadata_ref()) {
// defined in sync with
// thrift/lib/cpp2/transport/core/RpcMetadataUtil.h
// Setting user exception name and content
// The name and what headers are consumed (erased) once copied into the
// exception metadata.
static const auto uex =
std::string(apache::thrift::detail::kHeaderUex);
if (auto uexPtr = folly::get_ptr(*otherMetadataRef, uex)) {
exceptionMetadataBase.name_utf8_ref() = *uexPtr;
otherMetadataRef->erase(uex);
}
static const auto uexw =
std::string(apache::thrift::detail::kHeaderUexw);
if (auto uexwPtr = folly::get_ptr(*otherMetadataRef, uexw)) {
exceptionMetadataBase.what_utf8_ref() = *uexwPtr;
otherMetadataRef->erase(uexw);
}
// Setting user declared exception classification
// Unlike the two headers above, this header is read but not erased.
static const auto exMeta =
std::string(apache::thrift::detail::kHeaderExMeta);
if (auto metaPtr = folly::get_ptr(*otherMetadataRef, exMeta)) {
ErrorClassification errorClassification =
apache::thrift::detail::deserializeErrorClassification(
*metaPtr);
declaredExceptionMetadata.errorClassification_ref() =
std::move(errorClassification);
}
}
PayloadExceptionMetadata exceptionMetadata;
exceptionMetadata.set_declaredException(
std::move(declaredExceptionMetadata));
exceptionMetadataBase.metadata_ref() = std::move(exceptionMetadata);
payloadMetadata.set_exceptionMetadata(
std::move(exceptionMetadataBase));
}
metadata.payloadMetadata_ref() = std::move(payloadMetadata);
break;
}
case MessageType::T_EXCEPTION: {
DCHECK_GE(version, 2);
preprocessProxiedExceptionHeaders(metadata, version);
TApplicationException ex;
::apache::thrift::detail::deserializeExceptionBody(&reader, &ex);
PayloadExceptionMetadataBase exceptionMetadataBase;
exceptionMetadataBase.what_utf8_ref() = ex.getMessage();
auto otherMetadataRef = metadata.otherMetadata_ref();
// preprocessProxiedExceptionHeaders() has already renamed this header, so it
// must no longer be present under its proxied name.
DCHECK(
!otherMetadataRef ||
!folly::get_ptr(*otherMetadataRef, "servicerouter:sr_error"));
if (auto proxyErrorPtr = otherMetadataRef
? folly::get_ptr(
*otherMetadataRef, "servicerouter:sr_internal_error")
: nullptr) {
// Proxy exception: the payload is replaced by the base64-decoded header
// value, and the consumed headers are removed.
exceptionMetadataBase.name_utf8_ref() = "ProxyException";
PayloadExceptionMetadata exceptionMetadata;
exceptionMetadata.set_proxyException(PayloadProxyExceptionMetadata());
exceptionMetadataBase.metadata_ref() = std::move(exceptionMetadata);
payload = protocol::base64Decode(*proxyErrorPtr);
otherMetadataRef->erase("servicerouter:sr_internal_error");
otherMetadataRef->erase("ex");
} else {
DCHECK_GE(version, 3);
auto exPtr = otherMetadataRef
? folly::get_ptr(*otherMetadataRef, "ex")
: nullptr;
auto uexPtr = otherMetadataRef
? folly::get_ptr(*otherMetadataRef, "uex")
: nullptr;
// Translate the "ex" header (if any) into a ResponseRpcErrorCode. A match
// short-circuits the function with a ResponseRpcError.
if (auto errorCode = [&]() -> folly::Optional<ResponseRpcErrorCode> {
if (exPtr) {
// NOTE(review): a queue-overloaded header combined with a LOADSHEDDING
// exception type is reported as SHUTDOWN rather than QUEUE_OVERLOADED --
// presumably this is how the sender encodes shutdown rejections; confirm.
if (*exPtr == kQueueOverloadedErrorCode &&
ex.getType() == TApplicationException::LOADSHEDDING) {
return ResponseRpcErrorCode::SHUTDOWN;
}
// Lookup table from header error-code string to ResponseRpcErrorCode. It is
// heap-allocated once and bound to a static reference, so it is never freed.
static const auto& errorCodeMap = *new std::unordered_map<
std::string,
ResponseRpcErrorCode>(
{{kUnknownErrorCode, ResponseRpcErrorCode::UNKNOWN},
{kOverloadedErrorCode, ResponseRpcErrorCode::OVERLOAD},
{kAppOverloadedErrorCode,
ResponseRpcErrorCode::APP_OVERLOAD},
{kTaskExpiredErrorCode,
ResponseRpcErrorCode::TASK_EXPIRED},
{kQueueOverloadedErrorCode,
ResponseRpcErrorCode::QUEUE_OVERLOADED},
{kInjectedFailureErrorCode,
ResponseRpcErrorCode::INJECTED_FAILURE},
{kServerQueueTimeoutErrorCode,
ResponseRpcErrorCode::QUEUE_TIMEOUT},
{kResponseTooBigErrorCode,
ResponseRpcErrorCode::RESPONSE_TOO_BIG},
{kMethodUnknownErrorCode,
ResponseRpcErrorCode::UNKNOWN_METHOD},
{kRequestTypeDoesntMatchServiceFunctionType,
ResponseRpcErrorCode::WRONG_RPC_KIND},
{kInteractionIdUnknownErrorCode,
ResponseRpcErrorCode::UNKNOWN_INTERACTION_ID},
{kInteractionConstructorErrorErrorCode,
ResponseRpcErrorCode::INTERACTION_CONSTRUCTOR_ERROR},
{kRequestParsingErrorCode,
ResponseRpcErrorCode::REQUEST_PARSING_FAILURE},
{kChecksumMismatchErrorCode,
ResponseRpcErrorCode::CHECKSUM_MISMATCH},
{kUnimplementedMethodErrorCode,
ResponseRpcErrorCode::UNIMPLEMENTED_METHOD}});
if (auto errorCode = folly::get_ptr(errorCodeMap, *exPtr)) {
return *errorCode;
}
}
return folly::none;
}()) {
return makeResponseRpcError(*errorCode, ex.getMessage(), metadata);
}
// No recognised error code: report an application exception instead.
if (uexPtr) {
exceptionMetadataBase.name_utf8_ref() = *uexPtr;
otherMetadataRef->erase("uex");
}
PayloadExceptionMetadata exceptionMetadata;
// Blame is CLIENT when the "ex" header equals kAppClientErrorCode and SERVER
// otherwise. Versions below 8 use the DEPRECATED_* union members instead of
// appUnknownException.
if (exPtr && *exPtr == kAppClientErrorCode) {
if (version < 8) {
exceptionMetadata.set_DEPRECATED_appClientException(
PayloadAppClientExceptionMetadata());
} else {
PayloadAppUnknownExceptionMetdata aue;
aue.errorClassification_ref().ensure().blame_ref() =
ErrorBlame::CLIENT;
exceptionMetadata.set_appUnknownException(std::move(aue));
}
} else {
if (version < 8) {
exceptionMetadata.set_DEPRECATED_appServerException(
PayloadAppServerExceptionMetadata());
} else {
PayloadAppUnknownExceptionMetdata aue;
aue.errorClassification_ref().ensure().blame_ref() =
ErrorBlame::SERVER;
exceptionMetadata.set_appUnknownException(std::move(aue));
}
}
exceptionMetadataBase.metadata_ref() = std::move(exceptionMetadata);
// NOTE(review): IOBuf::clear() appears to reset only this (head) buffer, not
// the rest of a chain -- confirm the payload cannot be chained here, or that
// leaving trailing buffers in place is intended.
payload->clear();
if (otherMetadataRef) {
otherMetadataRef->erase("ex");
otherMetadataRef->erase("uexw");
}
}
PayloadMetadata payloadMetadata;
payloadMetadata.set_exceptionMetadata(std::move(exceptionMetadataBase));
metadata.payloadMetadata_ref() = std::move(payloadMetadata);
break;
}
default:
DCHECK_GE(version, 3);
return makeResponseRpcError(
ResponseRpcErrorCode::UNKNOWN, "Invalid message type", metadata);
}
} catch (...) {
// Anything thrown while reading the envelope is converted into an UNKNOWN
// ResponseRpcError carrying the exception's description.
DCHECK_GE(version, 3);
return makeResponseRpcError(
ResponseRpcErrorCode::UNKNOWN,
fmt::format(
"Invalid response payload envelope: {}",
folly::exceptionStr(std::current_exception()).toStdString()),
metadata);
}
return {};
}
/**
 * Prepares the first response of a request for sending.
 *
 * Steps, in order:
 *  1. A null `payload` is reported as an UNKNOWN ResponseRpcError.
 *  2. A `server_write_headers` application event is logged with the number
 *     and the names of the entries in `metadata.otherMetadata`.
 *  3. If `compressionConfig` is set, the compression codec is configured on
 *     `metadata` based on the payload's chain data length.
 *  4. The envelope is processed with the reader matching `protType`
 *     (binary or compact); any other protocol id is an UNKNOWN error.
 *
 * Returns an empty exception_wrapper on success, otherwise the error to
 * report. `metadata` and `payload` are updated in place.
 */
FOLLY_NODISCARD folly::exception_wrapper processFirstResponse(
    ResponseRpcMetadata& metadata,
    std::unique_ptr<folly::IOBuf>& payload,
    apache::thrift::protocol::PROTOCOL_TYPES protType,
    int32_t version,
    const folly::Optional<CompressionConfig>& compressionConfig) noexcept {
  if (payload == nullptr) {
    return makeResponseRpcError(
        ResponseRpcErrorCode::UNKNOWN,
        "serialization failed for response",
        metadata);
  }

  THRIFT_APPLICATION_EVENT(server_write_headers).log([&] {
    auto headersRef = metadata.otherMetadata_ref();
    auto headerCount = headersRef ? headersRef->size() : 0;
    std::vector<folly::dynamic> headerNames;
    if (headerCount) {
      headerNames.reserve(headerCount);
      for (const auto& header : *headersRef) {
        headerNames.push_back(header.first);
      }
    }
    return folly::dynamic::object("size", headerCount) //
        ("keys", folly::dynamic::array(std::move(headerNames)));
  });

  // apply compression if client has specified compression codec
  if (compressionConfig) {
    rocket::detail::setCompressionCodec(
        *compressionConfig, metadata, payload->computeChainDataLength());
  }

  DCHECK_GE(version, 1);
  if (protType == protocol::T_BINARY_PROTOCOL) {
    return processFirstResponseHelper<BinaryProtocolReader>(
        metadata, payload, version);
  }
  if (protType == protocol::T_COMPACT_PROTOCOL) {
    return processFirstResponseHelper<CompactProtocolReader>(
        metadata, payload, version);
  }

  // Neither binary nor compact: the response cannot be interpreted.
  DCHECK_GE(version, 3);
  return makeResponseRpcError(
      ResponseRpcErrorCode::UNKNOWN,
      "Invalid response payload protocol id",
      metadata);
}
template <typename Error, typename Callback, typename ResponseChannel>
void handleStreamError(
Error&& error, Callback& callback, ResponseChannel* channel) {
error.handle(
[&](RocketException& ex) {
std::exchange(callback, nullptr)->onFirstResponseError(std::move(ex));
},
[&](...) {
channel->sendErrorWrapped(
std::forward<Error>(error), kUnknownErrorCode);
});
}
} // namespace
// Constructs the server-side request object for a request-response RPC.
//
// `serverConfigs`, `metadata` and `connContext` are forwarded to
// ThriftRequestCore; the event base, frame context and version are stored on
// this object. A RequestsRegistry::DebugStub is then constructed directly in
// the storage referred to by `debugStubToInit` (placement new) and
// scheduleTimeouts() is called.
// NOTE(review): assumes `debugStubToInit` refers to storage that does not yet
// hold a live DebugStub -- confirm against the caller.
ThriftServerRequestResponse::ThriftServerRequestResponse(
RequestsRegistry::DebugStub& debugStubToInit,
folly::EventBase& evb,
server::ServerConfigs& serverConfigs,
RequestRpcMetadata&& metadata,
Cpp2ConnContext& connContext,
std::shared_ptr<folly::RequestContext> rctx,
RequestsRegistry& reqRegistry,
rocket::Payload&& debugPayload,
RocketServerFrameContext&& context,
int32_t version)
: ThriftRequestCore(serverConfigs, std::move(metadata), connContext),
evb_(evb),
context_(std::move(context)),
version_(version) {
// Construct the debug stub in place, in caller-provided storage.
new (&debugStubToInit) RequestsRegistry::DebugStub(
reqRegistry,
*this,
*getRequestContext(),
std::move(rctx),
getProtoId(),
std::move(debugPayload),
stateMachine_);
scheduleTimeouts();
}
/**
 * Sends the (single) response of a request-response RPC.
 *
 * The response is first run through processFirstResponse(). On success the
 * packed metadata + data is sent as one payload frame flagged next+complete.
 * On failure a RocketException is sent through the frame context (together
 * with `cb`); any other error goes through sendErrorWrapped() with
 * kUnknownErrorCode.
 */
void ThriftServerRequestResponse::sendThriftResponse(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    apache::thrift::MessageChannel::SendCallbackPtr cb) noexcept {
  auto failure = processFirstResponse(
      metadata, data, getProtoId(), version_, getCompressionConfig());
  if (!failure) {
    context_.sendPayload(
        pack(metadata, std::move(data)),
        Flags().next(true).complete(true),
        std::move(cb));
    return;
  }

  failure.handle(
      [&](RocketException& rocketEx) {
        context_.sendError(std::move(rocketEx), std::move(cb));
      },
      [&](...) { sendErrorWrapped(std::move(failure), kUnknownErrorCode); });
}
/**
 * Sends an exception response. For request-response RPCs this takes exactly
 * the same path as a regular response, so it forwards to sendThriftResponse().
 */
void ThriftServerRequestResponse::sendThriftException(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    apache::thrift::MessageChannel::SendCallbackPtr cb) noexcept {
  this->sendThriftResponse(
      std::move(metadata), std::move(data), std::move(cb));
}
/**
 * Sends an already-serialized error buffer as the response, without a send
 * callback. Forwards to sendThriftResponse().
 */
void ThriftServerRequestResponse::sendSerializedError(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> exbuf) noexcept {
  this->sendThriftResponse(
      std::move(metadata), std::move(exbuf), nullptr /* no send callback */);
}
/**
 * Closes the connection this request arrived on, passing `ew` along as the
 * reason.
 */
void ThriftServerRequestResponse::closeConnection(
    folly::exception_wrapper ew) noexcept {
  auto&& owningConnection = context_.connection();
  owningConnection.close(std::move(ew));
}
// Constructs the server-side request object for a one-way (fire-and-forget)
// RPC.
//
// `serverConfigs`, `metadata` and `connContext` are forwarded to
// ThriftRequestCore; the event base, frame context and the `onComplete`
// callback are stored on this object (the callback is invoked from the
// destructor). A RequestsRegistry::DebugStub is then constructed directly in
// the storage referred to by `debugStubToInit` (placement new) and
// scheduleTimeouts() is called.
// NOTE(review): assumes `debugStubToInit` refers to storage that does not yet
// hold a live DebugStub -- confirm against the caller.
ThriftServerRequestFnf::ThriftServerRequestFnf(
RequestsRegistry::DebugStub& debugStubToInit,
folly::EventBase& evb,
server::ServerConfigs& serverConfigs,
RequestRpcMetadata&& metadata,
Cpp2ConnContext& connContext,
std::shared_ptr<folly::RequestContext> rctx,
RequestsRegistry& reqRegistry,
rocket::Payload&& debugPayload,
RocketServerFrameContext&& context,
folly::Function<void()> onComplete)
: ThriftRequestCore(serverConfigs, std::move(metadata), connContext),
evb_(evb),
context_(std::move(context)),
onComplete_(std::move(onComplete)) {
// Construct the debug stub in place, in caller-provided storage.
new (&debugStubToInit) RequestsRegistry::DebugStub(
reqRegistry,
*this,
*getRequestContext(),
std::move(rctx),
getProtoId(),
std::move(debugPayload),
stateMachine_);
scheduleTimeouts();
}
/**
 * Runs the completion callback, if one was supplied, when the request object
 * is destroyed. The callback is moved out of the member before being invoked.
 */
ThriftServerRequestFnf::~ThriftServerRequestFnf() {
  auto completion = std::move(onComplete_);
  if (completion) {
    completion();
  }
}
// One-way requests have no response. Reaching this function is treated as a
// programming error: it logs at FATAL severity (glog's LOG(FATAL) terminates
// the process). All arguments are ignored.
void ThriftServerRequestFnf::sendThriftResponse(
ResponseRpcMetadata&&,
std::unique_ptr<folly::IOBuf>,
apache::thrift::MessageChannel::SendCallbackPtr) noexcept {
LOG(FATAL) << "One-way requests cannot send responses";
}
/**
 * Forwards to sendThriftResponse(), which for one-way requests only logs a
 * fatal error: a one-way request cannot deliver an exception to the client
 * either.
 */
void ThriftServerRequestFnf::sendThriftException(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    apache::thrift::MessageChannel::SendCallbackPtr cb) noexcept {
  this->sendThriftResponse(
      std::move(metadata), std::move(data), std::move(cb));
}
// Deliberately a no-op: both arguments are discarded and nothing is sent for
// a one-way request.
void ThriftServerRequestFnf::sendSerializedError(
ResponseRpcMetadata&&, std::unique_ptr<folly::IOBuf>) noexcept {}
/**
 * Closes the connection this request arrived on, passing `ew` along as the
 * reason.
 */
void ThriftServerRequestFnf::closeConnection(
    folly::exception_wrapper ew) noexcept {
  auto&& owningConnection = context_.connection();
  owningConnection.close(std::move(ew));
}
// Constructs the server-side request object for a streaming RPC.
//
// `serverConfigs`, `metadata` and `connContext` are forwarded to
// ThriftRequestCore; the event base, frame context, version, client callback
// pointer and processor are stored on this object. A
// RequestsRegistry::DebugStub is then constructed directly in the storage
// referred to by `debugStubToInit` (placement new). If the request has a
// compression config, it is propagated to the client callback before
// scheduleTimeouts() is called.
// NOTE(review): `clientCallback` is dereferenced without a null check when a
// compression config is present -- presumably callers always pass a valid
// pointer; confirm. Also assumes `debugStubToInit` refers to storage that
// does not yet hold a live DebugStub.
ThriftServerRequestStream::ThriftServerRequestStream(
RequestsRegistry::DebugStub& debugStubToInit,
folly::EventBase& evb,
server::ServerConfigs& serverConfigs,
RequestRpcMetadata&& metadata,
Cpp2ConnContext& connContext,
std::shared_ptr<folly::RequestContext> rctx,
RequestsRegistry& reqRegistry,
rocket::Payload&& debugPayload,
RocketServerFrameContext&& context,
int32_t version,
RocketStreamClientCallback* clientCallback,
std::shared_ptr<AsyncProcessor> cpp2Processor)
: ThriftRequestCore(serverConfigs, std::move(metadata), connContext),
evb_(evb),
context_(std::move(context)),
version_(version),
clientCallback_(clientCallback),
cpp2Processor_(std::move(cpp2Processor)) {
// Construct the debug stub in place, in caller-provided storage.
new (&debugStubToInit) RequestsRegistry::DebugStub(
reqRegistry,
*this,
*getRequestContext(),
std::move(rctx),
getProtoId(),
std::move(debugPayload),
stateMachine_);
if (auto compressionConfig = getCompressionConfig()) {
clientCallback_->setCompressionConfig(*compressionConfig);
}
scheduleTimeouts();
}
// Stream requests must answer through sendStreamThriftResponse(). Reaching
// this function is treated as a programming error: it logs at FATAL severity
// (glog's LOG(FATAL) terminates the process). All arguments are ignored.
void ThriftServerRequestStream::sendThriftResponse(
ResponseRpcMetadata&&,
std::unique_ptr<folly::IOBuf>,
apache::thrift::MessageChannel::SendCallbackPtr) noexcept {
LOG(FATAL) << "Stream requests must respond via sendStreamThriftResponse";
}
/**
 * Reports an exception for a stream request by forwarding to
 * sendSerializedError(). The send callback argument is ignored.
 */
void ThriftServerRequestStream::sendThriftException(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    apache::thrift::MessageChannel::SendCallbackPtr) noexcept {
  this->sendSerializedError(std::move(metadata), std::move(data));
}
/**
 * Sends the first response of a stream and hands the stream's server callback
 * to the client callback.
 *
 * Returns false when no `stream` was supplied (the data is then sent as a
 * serialized error) or when processing the first response fails; otherwise
 * returns whatever the client callback's onFirstResponse() returns. On the
 * success path ownership of `stream` is released to the client callback.
 */
bool ThriftServerRequestStream::sendStreamThriftResponse(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    StreamServerCallbackPtr stream) noexcept {
  if (!stream) {
    sendSerializedError(std::move(metadata), std::move(data));
    return false;
  }

  auto failure = processFirstResponse(
      metadata, data, getProtoId(), version_, getCompressionConfig());
  if (failure) {
    handleStreamError(std::move(failure), clientCallback_, this);
    return false;
  }

  context_.unsetMarkRequestComplete();
  // Wire the two callbacks together before announcing the first response.
  stream->resetClientCallback(*clientCallback_);
  clientCallback_->setProtoId(getProtoId());
  return clientCallback_->onFirstResponse(
      FirstResponsePayload{std::move(data), std::move(metadata)},
      nullptr /* evb */,
      stream.release());
}
/**
 * Sends the first response of a stream produced by a ServerStreamFactory.
 *
 * An empty factory causes the data to be sent as a serialized error. If
 * processing the first response fails, the failure is reported through
 * handleStreamError(). Otherwise the factory is invoked with the first
 * response payload, the client callback and this request's event base.
 */
void ThriftServerRequestStream::sendStreamThriftResponse(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    apache::thrift::detail::ServerStreamFactory&& stream) noexcept {
  if (!stream) {
    sendSerializedError(std::move(metadata), std::move(data));
    return;
  }

  auto failure = processFirstResponse(
      metadata, data, getProtoId(), version_, getCompressionConfig());
  if (failure) {
    handleStreamError(std::move(failure), clientCallback_, this);
    return;
  }

  context_.unsetMarkRequestComplete();
  clientCallback_->setProtoId(getProtoId());
  stream(
      apache::thrift::FirstResponsePayload{
          std::move(data), std::move(metadata)},
      clientCallback_,
      &evb_);
}
/**
 * Reports a serialized error as the first (and only) response of a stream.
 *
 * The buffer is run through processFirstResponse(); if that fails, the
 * failure is reported via handleStreamError(). Otherwise the client callback
 * member is reset to nullptr and the previous callback receives the payload
 * wrapped in an EncodedFirstResponseError through onFirstResponseError().
 */
void ThriftServerRequestStream::sendSerializedError(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> exbuf) noexcept {
  auto failure = processFirstResponse(
      metadata, exbuf, getProtoId(), version_, getCompressionConfig());
  if (failure) {
    handleStreamError(std::move(failure), clientCallback_, this);
    return;
  }

  // Detach the callback from this request before notifying it.
  auto* detachedCallback = std::exchange(clientCallback_, nullptr);
  detachedCallback->onFirstResponseError(
      folly::make_exception_wrapper<
          thrift::detail::EncodedFirstResponseError>(
          FirstResponsePayload(std::move(exbuf), std::move(metadata))));
}
/**
 * Closes the connection this request arrived on, passing `ew` along as the
 * reason.
 */
void ThriftServerRequestStream::closeConnection(
    folly::exception_wrapper ew) noexcept {
  auto&& owningConnection = context_.connection();
  owningConnection.close(std::move(ew));
}
// Constructs the server-side request object for a sink RPC.
//
// `serverConfigs`, `metadata` and `connContext` are forwarded to
// ThriftRequestCore; the event base, frame context, version, client callback
// pointer and processor are stored on this object. A
// RequestsRegistry::DebugStub is then constructed directly in the storage
// referred to by `debugStubToInit` (placement new). If the request has a
// compression config, it is propagated to the client callback before
// scheduleTimeouts() is called.
// NOTE(review): `clientCallback` is dereferenced without a null check when a
// compression config is present -- presumably callers always pass a valid
// pointer; confirm. Also assumes `debugStubToInit` refers to storage that
// does not yet hold a live DebugStub.
ThriftServerRequestSink::ThriftServerRequestSink(
RequestsRegistry::DebugStub& debugStubToInit,
folly::EventBase& evb,
server::ServerConfigs& serverConfigs,
RequestRpcMetadata&& metadata,
Cpp2ConnContext& connContext,
std::shared_ptr<folly::RequestContext> rctx,
RequestsRegistry& reqRegistry,
rocket::Payload&& debugPayload,
RocketServerFrameContext&& context,
int32_t version,
RocketSinkClientCallback* clientCallback,
std::shared_ptr<AsyncProcessor> cpp2Processor)
: ThriftRequestCore(serverConfigs, std::move(metadata), connContext),
evb_(evb),
context_(std::move(context)),
version_(version),
clientCallback_(clientCallback),
cpp2Processor_(std::move(cpp2Processor)) {
// Construct the debug stub in place, in caller-provided storage.
new (&debugStubToInit) RequestsRegistry::DebugStub(
reqRegistry,
*this,
*getRequestContext(),
std::move(rctx),
getProtoId(),
std::move(debugPayload),
stateMachine_);
if (auto compressionConfig = getCompressionConfig()) {
clientCallback_->setCompressionConfig(*compressionConfig);
}
scheduleTimeouts();
}
// Sink requests must answer through sendSinkThriftResponse(). Reaching this
// function is treated as a programming error: it logs at FATAL severity
// (glog's LOG(FATAL) terminates the process). All arguments are ignored.
void ThriftServerRequestSink::sendThriftResponse(
ResponseRpcMetadata&&,
std::unique_ptr<folly::IOBuf>,
apache::thrift::MessageChannel::SendCallbackPtr) noexcept {
LOG(FATAL) << "Sink requests must respond via sendSinkThriftResponse";
}
/**
 * Reports an exception for a sink request by forwarding to
 * sendSerializedError(). The send callback argument is ignored.
 */
void ThriftServerRequestSink::sendThriftException(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    apache::thrift::MessageChannel::SendCallbackPtr) noexcept {
  this->sendSerializedError(std::move(metadata), std::move(data));
}
/**
 * Reports a serialized error as the first (and only) response of a sink.
 *
 * The buffer is run through processFirstResponse(); if that fails, the
 * failure is reported via handleStreamError(). Otherwise the client callback
 * member is reset to nullptr and the previous callback receives the payload
 * wrapped in an EncodedFirstResponseError through onFirstResponseError().
 */
void ThriftServerRequestSink::sendSerializedError(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> exbuf) noexcept {
  auto failure = processFirstResponse(
      metadata, exbuf, getProtoId(), version_, getCompressionConfig());
  if (failure) {
    handleStreamError(std::move(failure), clientCallback_, this);
    return;
  }

  // Detach the callback from this request before notifying it.
  auto* detachedCallback = std::exchange(clientCallback_, nullptr);
  detachedCallback->onFirstResponseError(
      folly::make_exception_wrapper<
          thrift::detail::EncodedFirstResponseError>(
          FirstResponsePayload(std::move(exbuf), std::move(metadata))));
}
#if FOLLY_HAS_COROUTINES
// Sends the first response of a sink whose consumer is a SinkConsumerImpl,
// then starts the server-side sink bridge coroutine.
//
// An empty `sinkConsumer` causes the data to be sent as a serialized error.
// If processing the first response fails, the failure is reported through
// handleStreamError(). Otherwise a ServerSinkBridge is created from the
// consumer, the client callback is given the first response together with
// the bridge, and ServerSinkBridge::start is scheduled on the consumer's
// executor.
void ThriftServerRequestSink::sendSinkThriftResponse(
ResponseRpcMetadata&& metadata,
std::unique_ptr<folly::IOBuf> data,
apache::thrift::detail::SinkConsumerImpl&& sinkConsumer) noexcept {
if (!sinkConsumer) {
sendSerializedError(std::move(metadata), std::move(data));
return;
}
if (auto error = processFirstResponse(
metadata, data, getProtoId(), version_, getCompressionConfig())) {
handleStreamError(std::move(error), clientCallback_, this);
return;
}
context_.unsetMarkRequestComplete();
// The executor pointer and chunk timeout are read here, before
// `sinkConsumer` is moved into the bridge below.
// NOTE(review): the raw executor pointer is used after the move --
// presumably the bridge keeps the executor alive; confirm.
auto* executor = sinkConsumer.executor.get();
clientCallback_->setProtoId(getProtoId());
clientCallback_->setChunkTimeout(sinkConsumer.chunkTimeout);
auto serverCallback = apache::thrift::detail::ServerSinkBridge::create(
std::move(sinkConsumer), *getEventBase(), clientCallback_);
// The client callback receives a non-owning pointer; ownership of the bridge
// is moved into the coroutine started below.
clientCallback_->onFirstResponse(
FirstResponsePayload{std::move(data), std::move(metadata)},
nullptr /* evb */,
serverCallback.get());
folly::coro::co_invoke(
&apache::thrift::detail::ServerSinkBridge::start,
std::move(serverCallback))
.scheduleOn(executor)
.start();
}
/**
 * Sends the first response of a sink and hands the sink's server callback to
 * the client callback.
 *
 * Returns false when no `serverCallback` was supplied (the data is then sent
 * as a serialized error) or when processing the first response fails;
 * otherwise returns whatever the client callback's onFirstResponse() returns.
 * On the success path ownership of `serverCallback` is released to the client
 * callback.
 */
bool ThriftServerRequestSink::sendSinkThriftResponse(
    ResponseRpcMetadata&& metadata,
    std::unique_ptr<folly::IOBuf> data,
    SinkServerCallbackPtr serverCallback) noexcept {
  if (!serverCallback) {
    sendSerializedError(std::move(metadata), std::move(data));
    return false;
  }

  auto failure = processFirstResponse(
      metadata, data, getProtoId(), version_, getCompressionConfig());
  if (failure) {
    handleStreamError(std::move(failure), clientCallback_, this);
    return false;
  }

  context_.unsetMarkRequestComplete();
  // Wire the two callbacks together before announcing the first response.
  serverCallback->resetClientCallback(*clientCallback_);
  clientCallback_->setProtoId(getProtoId());
  return clientCallback_->onFirstResponse(
      FirstResponsePayload{std::move(data), std::move(metadata)},
      nullptr /* evb */,
      serverCallback.release());
}
#endif
/**
 * Closes the connection this request arrived on, passing `ew` along as the
 * reason.
 */
void ThriftServerRequestSink::closeConnection(
    folly::exception_wrapper ew) noexcept {
  auto&& owningConnection = context_.connection();
  owningConnection.close(std::move(ew));
}
} // namespace rocket
} // namespace thrift
} // namespace apache