thrift/lib/cpp2/server/Cpp2Connection.cpp (824 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <thrift/lib/cpp2/server/Cpp2Connection.h> #include <folly/Overload.h> #include <thrift/lib/cpp/transport/THeader.h> #include <thrift/lib/cpp2/Flags.h> #include <thrift/lib/cpp2/GeneratedCodeHelper.h> #include <thrift/lib/cpp2/async/AsyncProcessorHelper.h> #include <thrift/lib/cpp2/async/ResponseChannel.h> #include <thrift/lib/cpp2/protocol/BinaryProtocol.h> #include <thrift/lib/cpp2/protocol/CompactProtocol.h> #include <thrift/lib/cpp2/server/Cpp2Worker.h> #include <thrift/lib/cpp2/server/LoggingEventHelper.h> #include <thrift/lib/cpp2/server/MonitoringMethodNames.h> #include <thrift/lib/cpp2/server/ThriftServer.h> #include <thrift/lib/cpp2/transport/rocket/server/RocketRoutingHandler.h> THRIFT_FLAG_DEFINE_bool(server_rocket_upgrade_enabled, false); THRIFT_FLAG_DEFINE_bool(server_header_reject_http, true); THRIFT_FLAG_DEFINE_bool(server_header_reject_framed, true); THRIFT_FLAG_DEFINE_bool(server_header_reject_unframed, true); THRIFT_FLAG_DEFINE_int64(monitoring_over_header_logging_sample_rate, 1'000'000); namespace apache { namespace thrift { using namespace std; namespace { // This is a SendCallback used for transport upgrade from header to rocket class TransportUpgradeSendCallback : public MessageChannel::SendCallback { public: TransportUpgradeSendCallback( const std::shared_ptr<folly::AsyncTransport>& transport, const folly::SocketAddress* peerAddress, Cpp2Worker* cpp2Worker, Cpp2Connection* cpp2Conn, HeaderServerChannel* headerChannel) : transport_(transport), peerAddress_(peerAddress), cpp2Worker_(cpp2Worker), cpp2Conn_(cpp2Conn), headerChannel_(headerChannel) {} void sendQueued() override {} void messageSent() override { SCOPE_EXIT { delete this; }; // do the transport upgrade for (auto& routingHandler : *cpp2Worker_->getServer()->getRoutingHandlers()) { if (auto handler = dynamic_cast<RocketRoutingHandler*>(routingHandler.get())) { // Close the channel, since the transport is transferring to rocket DCHECK(headerChannel_); headerChannel_->setCallback(nullptr); headerChannel_->setTransport(nullptr); headerChannel_->closeNow(); DCHECK(transport_.use_count() == 1); // Only do upgrade if transport_ is the only one managing the socket. // Otherwise close the connection. if (transport_.use_count() == 1) { // Steal the transport from header channel auto uPtr = std::get_deleter< apache::thrift::transport::detail::ReleaseDeleter< folly::AsyncTransport, folly::DelayedDestruction::Destructor>>(transport_) ->stealPtr(); // Let RocketRoutingHandler handle the connection from here handler->handleConnection( cpp2Worker_->getConnectionManager(), std::move(uPtr), peerAddress_, wangle::TransportInfo(), cpp2Worker_->getWorkerShared()); } DCHECK(cpp2Conn_); cpp2Conn_->stop(); break; } } } void messageSendError(folly::exception_wrapper&&) override { delete this; } private: const std::shared_ptr<folly::AsyncTransport>& transport_; const folly::SocketAddress* peerAddress_; Cpp2Worker* cpp2Worker_; Cpp2Connection* cpp2Conn_; HeaderServerChannel* headerChannel_; }; } // namespace Cpp2Connection::Cpp2Connection( const std::shared_ptr<folly::AsyncTransport>& transport, const folly::SocketAddress* address, std::shared_ptr<Cpp2Worker> worker, const std::shared_ptr<HeaderServerChannel>& serverChannel) : processorFactory_(worker->getServer()->getDecoratedProcessorFactory()), serviceMetadata_(worker->getMetadataForService(processorFactory_)), serviceRequestInfoMap_(processorFactory_.getServiceRequestInfoMap()), processor_(processorFactory_.getProcessor()), duplexChannel_( worker->getServer()->isDuplex() ? std::make_unique<DuplexChannel>( DuplexChannel::Who::SERVER, transport) : nullptr), channel_( serverChannel ? serverChannel : // used by client duplexChannel_ ? duplexChannel_->getServerChannel() : // server std::shared_ptr<HeaderServerChannel>( new HeaderServerChannel(transport), folly::DelayedDestruction::Destructor())), worker_(std::move(worker)), context_( address, transport.get(), worker_->getServer()->getEventBaseManager(), duplexChannel_ ? duplexChannel_->getClientChannel() : nullptr, nullptr, worker_->getServer()->getClientIdentityHook(), worker_.get()), transport_(transport), executor_(worker_->getServer()->getExecutor()) { if (!useResourcePoolsFlagsSet()) { threadManager_ = worker_->getServer()->getThreadManager(); } context_.setTransportType(Cpp2ConnContext::TransportType::HEADER); if (auto* observer = worker_->getServer()->getObserver()) { channel_->setSampleRate(observer->getSampleRate()); } for (const auto& handler : worker_->getServer()->getEventHandlersUnsafe()) { handler->newConnection(&context_); } } Cpp2Connection::~Cpp2Connection() { for (const auto& handler : worker_->getServer()->getEventHandlersUnsafe()) { handler->connectionDestroyed(&context_); } if (connectionAdded_) { if (auto* observer = worker_->getServer()->getObserver()) { observer->connClosed(); } } channel_.reset(); } void Cpp2Connection::stop() { if (getConnectionManager()) { getConnectionManager()->removeConnection(this); } context_.connectionClosed(); for (auto req : activeRequests_) { VLOG(1) << "Task killed due to channel close: " << context_.getPeerAddress()->describe(); if (!req->isOneway()) { req->cancelRequest(); if (auto* observer = worker_->getServer()->getObserver()) { observer->taskKilled(); } } } if (channel_) { channel_->setCallback(nullptr); // Release the socket to avoid long CLOSE_WAIT times channel_->closeNow(); } transport_.reset(); this_.reset(); } void Cpp2Connection::timeoutExpired() noexcept { // Only disconnect if there are no active requests. No need to set another // timeout here because it's going to be set when all the requests are // handled. if (activeRequests_.empty()) { disconnect("idle timeout"); } } void Cpp2Connection::disconnect(const char* comment) noexcept { // This must be the last call, it may delete this. auto guard = folly::makeGuard([&] { stop(); }); VLOG(1) << "ERROR: Disconnect: " << comment << " on channel: " << context_.getPeerAddress()->describe(); if (auto* observer = worker_->getServer()->getObserver()) { observer->connDropped(); } } void Cpp2Connection::setServerHeaders( transport::THeader::StringToStringMap& writeHeaders) { if (getWorker()->isStopping()) { writeHeaders["connection"] = "goaway"; } } void Cpp2Connection::setServerHeaders( HeaderServerChannel::HeaderRequest& request) { auto& writeHeaders = request.getHeader()->mutableWriteHeaders(); setServerHeaders(writeHeaders); const auto& readHeaders = request.getHeader()->getHeaders(); auto ptr = folly::get_ptr(readHeaders, THeader::QUERY_LOAD_HEADER); if (ptr) { auto load = getWorker()->getServer()->getLoad(*ptr); writeHeaders[THeader::QUERY_LOAD_HEADER] = folly::to<std::string>(load); } } void Cpp2Connection::requestTimeoutExpired() { VLOG(1) << "ERROR: Task expired on channel: " << context_.getPeerAddress()->describe(); if (auto* observer = worker_->getServer()->getObserver()) { observer->taskTimeout(); } } void Cpp2Connection::queueTimeoutExpired() { VLOG(1) << "ERROR: Queue timeout on channel: " << context_.getPeerAddress()->describe(); if (auto* observer = worker_->getServer()->getObserver()) { observer->queueTimeout(); } } bool Cpp2Connection::pending() { return transport_ ? transport_->isPending() : false; } void Cpp2Connection::handleAppError( std::unique_ptr<HeaderServerChannel::HeaderRequest> req, const std::string& name, const std::string& message, bool isClientError) { static const std::string headerEx = "uex"; static const std::string headerExWhat = "uexw"; req->getHeader()->setHeader(headerEx, name); req->getHeader()->setHeader(headerExWhat, message); killRequest( std::move(req), TApplicationException::UNKNOWN, isClientError ? kAppClientErrorCode : kAppServerErrorCode, message.c_str()); } void Cpp2Connection::killRequest( std::unique_ptr<HeaderServerChannel::HeaderRequest> req, TApplicationException::TApplicationExceptionType reason, const std::string& errorCode, const char* comment) { VLOG(1) << "ERROR: Task killed: " << comment << ": " << context_.getPeerAddress()->getAddressStr(); auto server = worker_->getServer(); if (auto* observer = server->getObserver()) { if (reason == TApplicationException::TApplicationExceptionType::LOADSHEDDING) { observer->serverOverloaded(); } else { observer->taskKilled(); } } // Nothing to do for Thrift oneway request. if (req->isOneway()) { return; } setServerHeaders(*req); req->sendErrorWrapped( folly::make_exception_wrapper<TApplicationException>(reason, comment), errorCode); } // Response Channel callbacks void Cpp2Connection::requestReceived( unique_ptr<HeaderServerChannel::HeaderRequest>&& hreq) { auto& samplingStatus = hreq->getSamplingStatus(); std::chrono::steady_clock::time_point readEnd; if (samplingStatus.isEnabled()) { readEnd = std::chrono::steady_clock::now(); } bool useHttpHandler = false; // Any POST not for / should go to the status handler if (hreq->getHeader()->getClientType() == THRIFT_HTTP_SERVER_TYPE) { auto buf = hreq->getBuf(); // 7 == length of "POST / " - we are matching on the path if (buf->length() >= 7 && 0 == strncmp(reinterpret_cast<const char*>(buf->data()), "POST", 4) && buf->data()[6] != ' ') { useHttpHandler = true; } // Any GET should use the handler if (buf->length() >= 3 && 0 == strncmp(reinterpret_cast<const char*>(buf->data()), "GET", 3)) { useHttpHandler = true; } // Any HEAD should use the handler if (buf->length() >= 4 && 0 == strncmp(reinterpret_cast<const char*>(buf->data()), "HEAD", 4)) { useHttpHandler = true; } } if (useHttpHandler && worker_->getServer()->getGetHandler()) { worker_->getServer()->getGetHandler()( worker_->getEventBase(), worker_->getConnectionManager(), transport_, hreq->extractBuf()); // Close the channel, since the handler now owns the socket. channel_->setCallback(nullptr); channel_->setTransport(nullptr); stop(); return; } if (THRIFT_FLAG(server_header_reject_http) && (hreq->getHeader()->getClientType() == THRIFT_HTTP_SERVER_TYPE || hreq->getHeader()->getClientType() == THRIFT_HTTP_CLIENT_TYPE || hreq->getHeader()->getClientType() == THRIFT_HTTP_GET_CLIENT_TYPE)) { disconnect("Rejecting HTTP connection over Header"); return; } if (THRIFT_FLAG(server_header_reject_framed) && (hreq->getHeader()->getClientType() == THRIFT_FRAMED_DEPRECATED || hreq->getHeader()->getClientType() == THRIFT_FRAMED_COMPACT)) { disconnect("Rejecting framed connection over Header"); return; } if (THRIFT_FLAG(server_header_reject_unframed) && (hreq->getHeader()->getClientType() == THRIFT_UNFRAMED_DEPRECATED || hreq->getHeader()->getClientType() == THRIFT_UNFRAMED_COMPACT_DEPRECATED)) { disconnect("Rejecting unframed connection over Header"); return; } auto protoId = static_cast<apache::thrift::protocol::PROTOCOL_TYPES>( hreq->getHeader()->getProtocolId()); auto msgBegin = apache::thrift::detail::ap::deserializeMessageBegin( *hreq->getBuf(), protoId); std::string& methodName = msgBegin.methodName; const auto& meta = msgBegin.metadata; // Transport upgrade: check if client requested transport upgrade from header // to rocket. If yes, reply immediately and upgrade the transport after // sending the reply. if (methodName == "upgradeToRocket") { if (THRIFT_FLAG(server_rocket_upgrade_enabled)) { ResponsePayload response; switch (protoId) { case apache::thrift::protocol::T_BINARY_PROTOCOL: response = upgradeToRocketReply<apache::thrift::BinaryProtocolWriter>( meta.seqId); break; case apache::thrift::protocol::T_COMPACT_PROTOCOL: response = upgradeToRocketReply<apache::thrift::CompactProtocolWriter>( meta.seqId); break; default: LOG(DFATAL) << "Unsupported protocol found"; // if protocol is neither binary or compact, we want to kill the // request and abort upgrade killRequest( std::move(hreq), TApplicationException::TApplicationExceptionType:: INVALID_PROTOCOL, kUnknownErrorCode, "invalid protocol used"); return; } hreq->sendReply( std::move(response), new TransportUpgradeSendCallback( transport_, context_.getPeerAddress(), getWorker(), this, channel_.get())); return; } else { killRequest( std::move(hreq), TApplicationException::TApplicationExceptionType::UNKNOWN_METHOD, kMethodUnknownErrorCode, "Rocket upgrade disabled"); return; } } if (worker_->getServer()->isHeaderDisabled()) { disconnect("Rejecting Header connection"); return; } using PerServiceMetadata = Cpp2Worker::PerServiceMetadata; const PerServiceMetadata::FindMethodResult methodMetadataResult = serviceMetadata_.findMethod(methodName); auto baseReqCtx = serviceMetadata_.getBaseContextForRequest(methodMetadataResult); auto rootid = worker_->getRequestsRegistry()->genRootId(); auto reqCtx = baseReqCtx ? folly::RequestContext::copyAsRoot(*baseReqCtx, rootid) : std::make_shared<folly::RequestContext>(rootid); folly::RequestContextScopeGuard rctx(reqCtx); auto server = worker_->getServer(); server->touchRequestTimestamp(); auto* observer = server->getObserver(); if (observer) { observer->receivedRequest(&methodName); } auto injectedFailure = server->maybeInjectFailure(); switch (injectedFailure) { case ThriftServer::InjectedFailure::NONE: break; case ThriftServer::InjectedFailure::ERROR: killRequest( std::move(hreq), TApplicationException::TApplicationExceptionType::INJECTED_FAILURE, kInjectedFailureErrorCode, "injected failure"); return; case ThriftServer::InjectedFailure::DROP: VLOG(1) << "ERROR: injected drop: " << context_.getPeerAddress()->getAddressStr(); return; case ThriftServer::InjectedFailure::DISCONNECT: disconnect("injected failure"); return; } if (server->getGetHeaderHandler()) { server->getGetHeaderHandler()(hreq->getHeader(), context_.getPeerAddress()); } if (auto overloadResult = server->checkOverload( &hreq->getHeader()->getHeaders(), &methodName)) { killRequest( std::move(hreq), TApplicationException::LOADSHEDDING, overloadResult.value(), "loadshedding request"); return; } if (auto preprocessResult = server->preprocess( {hreq->getHeader()->getHeaders(), methodName, context_}); !std::holds_alternative<std::monostate>(preprocessResult)) { folly::variant_match( preprocessResult, [&](AppClientException& ace) { handleAppError(std::move(hreq), ace.name(), ace.getMessage(), true); }, [&](AppOverloadedException& aoe) { killRequest( std::move(hreq), TApplicationException::LOADSHEDDING, kAppOverloadedErrorCode, aoe.getMessage().c_str()); }, [&](AppServerException& ase) { handleAppError(std::move(hreq), ase.name(), ase.getMessage(), false); }, [](std::monostate&) { folly::assume_unreachable(); }); return; } if (worker_->isStopping()) { killRequest( std::move(hreq), TApplicationException::TApplicationExceptionType::INTERNAL_ERROR, kQueueOverloadedErrorCode, "server shutting down"); return; } if (!server->shouldHandleRequestForMethod(methodName)) { killRequest( std::move(hreq), TApplicationException::TApplicationExceptionType::INTERNAL_ERROR, kQueueOverloadedErrorCode, "server not ready"); return; } // After this, the request buffer is no longer owned by the request // and will be released after deserializeRequest. auto serializedRequest = [&] { folly::IOBufQueue bufQueue; bufQueue.append(hreq->extractBuf()); bufQueue.trimStart(meta.size); return SerializedRequest(bufQueue.move()); }(); // We keep a clone of the request payload buffer for debugging purposes, but // the lifetime of payload should not necessarily be the same as its request // object's. auto debugPayload = rocket::Payload::makeCombined(serializedRequest.buffer->clone(), 0); std::chrono::milliseconds queueTimeout; std::chrono::milliseconds taskTimeout; std::chrono::milliseconds clientQueueTimeout = hreq->getHeader()->getClientQueueTimeout(); std::chrono::milliseconds clientTimeout = hreq->getHeader()->getClientTimeout(); auto differentTimeouts = server->getTaskExpireTimeForRequest( clientQueueTimeout, clientTimeout, queueTimeout, taskTimeout); folly::call_once(clientInfoFlag_, [&] { if (const auto& m = hreq->getHeader()->extractClientMetadata()) { context_.setClientMetadata(*m); } }); context_.setClientType(hreq->getHeader()->getClientType()); auto t2r = RequestsRegistry::makeRequest<Cpp2Request>( std::move(hreq), std::move(reqCtx), this_, std::move(debugPayload), std::move(methodName)); server->incActiveRequests(); if (samplingStatus.isEnabled()) { // Expensive operations; happens only when sampling is enabled auto& timestamps = t2r->getTimestamps(); timestamps.setStatus(samplingStatus); timestamps.readEnd = readEnd; timestamps.processBegin = std::chrono::steady_clock::now(); if (samplingStatus.isEnabledByServer() && observer) { if (threadManager_) { observer->queuedRequests(threadManager_->pendingUpstreamTaskCount()); } observer->activeRequests(server->getActiveRequests()); } } activeRequests_.insert(t2r.get()); auto reqContext = t2r->getContext(); if (observer) { observer->admittedRequest(&reqContext->getMethodName()); } if (differentTimeouts) { if (queueTimeout > std::chrono::milliseconds(0)) { scheduleTimeout(&t2r->queueTimeout_, queueTimeout); } } if (taskTimeout > std::chrono::milliseconds(0)) { scheduleTimeout(&t2r->taskTimeout_, taskTimeout); } if (clientTimeout > std::chrono::milliseconds::zero()) { reqContext->setRequestTimeout(clientTimeout); } else { reqContext->setRequestTimeout(taskTimeout); } // Log monitoring methods that are called over header interface so that they // can be migrated to rocket monitoring interface. LoggingSampler monitoringLogSampler{ THRIFT_FLAG(monitoring_over_header_logging_sample_rate)}; if (monitoringLogSampler.isSampled()) { if (isMonitoringMethodName(reqContext->getMethodName())) { THRIFT_CONNECTION_EVENT(monitoring_over_header) .logSampled(context_, monitoringLogSampler, [&] { return folly::dynamic::object( "method_name", reqContext->getMethodName()); }); } } try { ResponseChannelRequest::UniquePtr req = std::move(t2r); if (!apache::thrift::detail::ap::setupRequestContextWithMessageBegin( meta, protoId, req, reqContext, worker_->getEventBase())) { return; } folly::variant_match( methodMetadataResult, [&](PerServiceMetadata::MetadataNotImplemented) { logSetupConnectionEventsOnce(setupLoggingFlag_, context_); // The AsyncProcessorFactory does not implement createMethodMetadata // so we need to fallback to processSerializedCompressedRequest. processor_->processSerializedCompressedRequest( std::move(req), SerializedCompressedRequest(std::move(serializedRequest)), protoId, reqContext, worker_->getEventBase(), threadManager_.get()); }, [&](PerServiceMetadata::MetadataNotFound) { AsyncProcessorHelper::sendUnknownMethodError( std::move(req), reqContext->getMethodName()); }, [&](const PerServiceMetadata::MetadataFound& found) { logSetupConnectionEventsOnce(setupLoggingFlag_, context_); if (!server->resourcePoolSet().empty()) { // We need to process this using request pools const ServiceRequestInfo* serviceRequestInfo{nullptr}; if (serviceRequestInfoMap_) { serviceRequestInfo = &serviceRequestInfoMap_->get().at( reqContext->getMethodName()); } ServerRequest serverRequest( std::move(req), SerializedCompressedRequest(std::move(serializedRequest)), worker_->getEventBase(), reqContext, protoId, folly::RequestContext::saveContext(), processor_.get(), &found.metadata, serviceRequestInfo); auto poolResult = AsyncProcessorHelper::selectResourcePool( serverRequest, found.metadata); if (auto* reject = std::get_if<ServerRequestRejection>(&poolResult)) { auto errorCode = kAppOverloadedErrorCode; if (reject->applicationException().getType() == TApplicationException::UNKNOWN_METHOD) { errorCode = kMethodUnknownErrorCode; } serverRequest.request()->sendErrorWrapped( folly::exception_wrapper( folly::in_place, std::move(*reject).applicationException()), errorCode); return; } auto resourcePoolHandle = std::get_if<std::reference_wrapper<const ResourcePoolHandle>>( &poolResult); DCHECK( server->resourcePoolSet().hasResourcePool(*resourcePoolHandle)); auto& resourcePool = server->resourcePoolSet().resourcePool(*resourcePoolHandle); apache::thrift::detail::ServerRequestHelper::setExecutor( serverRequest, resourcePool.executor().value_or(nullptr)); auto result = resourcePool.accept(std::move(serverRequest)); if (result) { auto errorCode = kQueueOverloadedErrorCode; serverRequest.request()->sendErrorWrapped( folly::exception_wrapper( folly::in_place, std::move(std::move(result).value()) .applicationException()), errorCode); return; } } else { processor_->processSerializedCompressedRequestWithMetadata( std::move(req), SerializedCompressedRequest(std::move(serializedRequest)), found.metadata, protoId, reqContext, worker_->getEventBase(), threadManager_.get()); } }); } catch (...) { LOG(DFATAL) << "AsyncProcessor::process exception: " << folly::exceptionStr(std::current_exception()); } } void Cpp2Connection::channelClosed(folly::exception_wrapper&& ex) { // This must be the last call, it may delete this. auto guard = folly::makeGuard([&] { stop(); }); VLOG(4) << "Channel " << context_.getPeerAddress()->describe() << " closed: " << ex.what(); } void Cpp2Connection::removeRequest(Cpp2Request* req) { activeRequests_.erase(req); if (activeRequests_.empty()) { resetTimeout(); } } Cpp2Connection::Cpp2Request::Cpp2Request( RequestsRegistry::DebugStub& debugStubToInit, std::unique_ptr<HeaderServerChannel::HeaderRequest> req, std::shared_ptr<folly::RequestContext> rctx, std::shared_ptr<Cpp2Connection> con, rocket::Payload&& debugPayload, std::string&& methodName) : req_(std::move(req)), connection_(std::move(con)), // Note: tricky ordering here; see the note on connection_ in the class // definition. reqContext_( &connection_->context_, req_->getHeader(), std::move(methodName)), stateMachine_( util::includeInRecentRequestsCount(reqContext_.getMethodName()), connection_->getWorker() ->getServer() ->getAdaptiveConcurrencyController()), activeRequestsGuard_(connection_->getWorker()->getActiveRequestsGuard()) { new (&debugStubToInit) RequestsRegistry::DebugStub( *connection_->getWorker()->getRequestsRegistry(), *this, reqContext_, std::move(rctx), protocol::PROTOCOL_TYPES(req_->getHeader()->getProtocolId()), std::move(debugPayload), stateMachine_); queueTimeout_.request_ = this; taskTimeout_.request_ = this; } MessageChannel::SendCallback* Cpp2Connection::Cpp2Request::prepareSendCallback( MessageChannel::SendCallback* sendCallback, apache::thrift::server::TServerObserver* observer) { // If we are sampling this call, wrap it with a Cpp2Sample, which also // implements MessageChannel::SendCallback. Callers of sendReply/sendError // are responsible for cleaning up their own callbacks. MessageChannel::SendCallback* cb = sendCallback; auto& timestamps = getTimestamps(); if (stateMachine_.getStartedProcessing() && timestamps.getSamplingStatus().isEnabledByServer()) { // Cpp2Sample will delete itself when it's callback is called. cb = new Cpp2Sample(timestamps, observer, sendCallback); } return cb; } void Cpp2Connection::Cpp2Request::sendReply( ResponsePayload&& response, MessageChannel::SendCallback* sendCallback, folly::Optional<uint32_t>) { if (tryCancel()) { connection_->setServerHeaders(*req_); markProcessEnd(); auto* observer = connection_->getWorker()->getServer()->getObserver(); auto maxResponseSize = connection_->getWorker()->getServer()->getMaxResponseSize(); if (maxResponseSize != 0 && response.length() > maxResponseSize) { req_->sendErrorWrapped( folly::make_exception_wrapper<TApplicationException>( TApplicationException::TApplicationExceptionType::INTERNAL_ERROR, "Response size too big"), kResponseTooBigErrorCode, reqContext_.getMethodName(), reqContext_.getProtoSeqId(), prepareSendCallback(sendCallback, observer)); } else { req_->sendReply( std::move(response), prepareSendCallback(sendCallback, observer)); } cancelTimeout(); if (observer) { observer->sentReply(); } } } void Cpp2Connection::Cpp2Request::sendException( ResponsePayload&& response, MessageChannel::SendCallback* sendCallback) { if (tryCancel()) { connection_->setServerHeaders(*req_); markProcessEnd(); auto* observer = connection_->getWorker()->getServer()->getObserver(); auto maxResponseSize = connection_->getWorker()->getServer()->getMaxResponseSize(); if (maxResponseSize != 0 && response.length() > maxResponseSize) { req_->sendErrorWrapped( folly::make_exception_wrapper<TApplicationException>( TApplicationException::TApplicationExceptionType::INTERNAL_ERROR, "Response size too big"), kResponseTooBigErrorCode, reqContext_.getMethodName(), reqContext_.getProtoSeqId(), prepareSendCallback(sendCallback, observer)); } else { req_->sendException( std::move(response), prepareSendCallback(sendCallback, observer)); } cancelTimeout(); if (observer) { observer->sentReply(); } } } void Cpp2Connection::Cpp2Request::sendErrorWrapped( folly::exception_wrapper ew, std::string exCode) { if (tryCancel()) { connection_->setServerHeaders(*req_); markProcessEnd(); auto* observer = connection_->getWorker()->getServer()->getObserver(); req_->sendErrorWrapped( std::move(ew), std::move(exCode), reqContext_.getMethodName(), reqContext_.getProtoSeqId(), prepareSendCallback(nullptr, observer)); cancelTimeout(); } } void Cpp2Connection::Cpp2Request::sendTimeoutResponse( HeaderServerChannel::HeaderRequest::TimeoutResponseType responseType) { if (!tryCancel()) { // Timeout was not properly cancelled when request was previously // cancelled DCHECK(false); } auto* observer = connection_->getWorker()->getServer()->getObserver(); transport::THeader::StringToStringMap headers; connection_->setServerHeaders(headers); markProcessEnd(&headers); req_->sendTimeoutResponse( reqContext_.getMethodName(), reqContext_.getProtoSeqId(), prepareSendCallback(nullptr, observer), headers, responseType); cancelTimeout(); } void Cpp2Connection::Cpp2Request::sendQueueTimeoutResponse() { sendTimeoutResponse( HeaderServerChannel::HeaderRequest::TimeoutResponseType::QUEUE); connection_->queueTimeoutExpired(); } void Cpp2Connection::Cpp2Request::TaskTimeout::timeoutExpired() noexcept { request_->sendTimeoutResponse( HeaderServerChannel::HeaderRequest::TimeoutResponseType::TASK); request_->connection_->requestTimeoutExpired(); } void Cpp2Connection::Cpp2Request::QueueTimeout::timeoutExpired() noexcept { if (request_->stateMachine_.tryStopProcessing()) { request_->sendQueueTimeoutResponse(); } } void Cpp2Connection::Cpp2Request::markProcessEnd( transport::THeader::StringToStringMap* newHeaders) { auto& timestamps = getTimestamps(); auto& samplingStatus = timestamps.getSamplingStatus(); if (samplingStatus.isEnabled()) { timestamps.processEnd = std::chrono::steady_clock::now(); if (samplingStatus.isEnabledByClient()) { // Latency headers are set after processEnd itself. Can't be // done after write, since headers transform happens during write. setLatencyHeaders(timestamps, newHeaders); } } } void Cpp2Connection::Cpp2Request::setLatencyHeaders( const apache::thrift::server::TServerObserver::CallTimestamps& timestamps, transport::THeader::StringToStringMap* newHeaders) const { if (auto v = timestamps.processDelayLatencyUsec()) { setLatencyHeader( kQueueLatencyHeader.str(), folly::to<std::string>(*v), newHeaders); } if (auto v = timestamps.processLatencyUsec()) { setLatencyHeader( kProcessLatencyHeader.str(), folly::to<std::string>(*v), newHeaders); } } void Cpp2Connection::Cpp2Request::setLatencyHeader( const std::string& key, const std::string& value, transport::THeader::StringToStringMap* newHeaders) const { // newHeaders is used timeout exceptions, where req->header cannot be // mutated. if (newHeaders) { (*newHeaders)[key] = value; } else { req_->getHeader()->setHeader(key, value); } } Cpp2Connection::Cpp2Request::~Cpp2Request() { connection_->removeRequest(this); cancelTimeout(); connection_->getWorker()->getServer()->decActiveRequests(); } void Cpp2Connection::Cpp2Request::cancelRequest() { if (tryCancel()) { cancelTimeout(); } } Cpp2Connection::Cpp2Sample::Cpp2Sample( apache::thrift::server::TServerObserver::CallTimestamps& timestamps, apache::thrift::server::TServerObserver* observer, MessageChannel::SendCallback* chainedCallback) : timestamps_(timestamps), observer_(observer), chainedCallback_(chainedCallback) { DCHECK(observer != nullptr); } void Cpp2Connection::Cpp2Sample::sendQueued() { if (chainedCallback_ != nullptr) { chainedCallback_->sendQueued(); } timestamps_.writeBegin = std::chrono::steady_clock::now(); } void Cpp2Connection::Cpp2Sample::messageSent() { if (chainedCallback_ != nullptr) { chainedCallback_->messageSent(); } timestamps_.writeEnd = std::chrono::steady_clock::now(); delete this; } void Cpp2Connection::Cpp2Sample::messageSendError( folly::exception_wrapper&& e) { if (chainedCallback_ != nullptr) { chainedCallback_->messageSendError(std::move(e)); } timestamps_.writeEnd = std::chrono::steady_clock::now(); delete this; } Cpp2Connection::Cpp2Sample::~Cpp2Sample() { if (observer_) { observer_->callCompleted(timestamps_); } } } // namespace thrift } // namespace apache