thrift/lib/cpp2/transport/rocket/server/RocketServerConnection.cpp (894 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <thrift/lib/cpp2/transport/rocket/server/RocketServerConnection.h> #include <memory> #include <utility> #include <fmt/core.h> #include <folly/ExceptionString.h> #include <folly/ExceptionWrapper.h> #include <folly/GLog.h> #include <folly/Likely.h> #include <folly/MapUtil.h> #include <folly/Overload.h> #include <folly/ScopeGuard.h> #include <folly/SocketAddress.h> #include <folly/Utility.h> #include <folly/dynamic.h> #include <folly/io/Cursor.h> #include <folly/io/IOBuf.h> #include <folly/io/SocketOptionMap.h> #include <folly/io/async/AsyncTransport.h> #include <folly/io/async/DelayedDestruction.h> #include <wangle/acceptor/ConnectionManager.h> #include <thrift/lib/cpp/TApplicationException.h> #include <thrift/lib/cpp2/server/LoggingEvent.h> #include <thrift/lib/cpp2/transport/rocket/PayloadUtils.h> #include <thrift/lib/cpp2/transport/rocket/RocketException.h> #include <thrift/lib/cpp2/transport/rocket/framing/Frames.h> #include <thrift/lib/cpp2/transport/rocket/framing/Util.h> #include <thrift/lib/cpp2/transport/rocket/server/RocketServerFrameContext.h> #include <thrift/lib/cpp2/transport/rocket/server/RocketServerHandler.h> #include <thrift/lib/cpp2/transport/rocket/server/RocketSinkClientCallback.h> #include <thrift/lib/cpp2/transport/rocket/server/RocketStreamClientCallback.h> namespace apache { namespace thrift { namespace rocket { constexpr std::chrono::seconds RocketServerConnection::SocketDrainer::kTimeout; RocketServerConnection::RocketServerConnection( folly::AsyncTransport::UniquePtr socket, std::unique_ptr<RocketServerHandler> frameHandler, MemoryTracker& ingressMemoryTracker, MemoryTracker& egressMemoryTracker, const Config& cfg) : evb_(*socket->getEventBase()), socket_(std::move(socket)), rawSocket_( socket_ ? socket_->getUnderlyingTransport<folly::AsyncSocket>() : nullptr), frameHandler_(std::move(frameHandler)), streamStarvationTimeout_(cfg.streamStarvationTimeout), egressBufferBackpressureThreshold_(cfg.egressBufferBackpressureThreshold), egressBufferRecoverySize_( cfg.egressBufferBackpressureThreshold * cfg.egressBufferBackpressureRecoveryFactor), allocIOBufFnPtr_(cfg.allocIOBufFnPtr), writeBatcher_( *this, cfg.writeBatchingInterval, cfg.writeBatchingSize, cfg.writeBatchingByteSize), socketDrainer_(*this), ingressMemoryTracker_(ingressMemoryTracker), egressMemoryTracker_(egressMemoryTracker) { CHECK(socket_); CHECK(frameHandler_); socket_->setReadCB(&parser_); if (rawSocket_) { rawSocket_->setBufferCallback(this); rawSocket_->setSendTimeout(cfg.socketWriteTimeout.count()); if (cfg.socketOptions != nullptr) { auto sockfd = rawSocket_->getNetworkSocket(); for (auto& [option, value] : *cfg.socketOptions) { if (auto err = option.apply(sockfd, value)) { folly::SocketAddress address; rawSocket_->getAddress(&address); FB_LOG_EVERY_MS(WARNING, 60 * 1000) << fmt::format( "Could not apply SocketOption(level={}, optname={}, value={}) to socket {}", option.level, option.optname, value, address.describe()); } } } } } std::unique_ptr<folly::IOBuf> RocketServerConnection::customAlloc(size_t size) { if (allocIOBufFnPtr_ && *allocIOBufFnPtr_) { return (*allocIOBufFnPtr_)(size); } return nullptr; } RocketStreamClientCallback& RocketServerConnection::createStreamClientCallback( StreamId streamId, RocketServerConnection& connection, uint32_t initialRequestN) { auto callback = std::make_unique<RocketStreamClientCallback>( streamId, connection, initialRequestN); auto& callbackRef = *callback; streams_.emplace(streamId, std::move(callback)); return callbackRef; } RocketSinkClientCallback& RocketServerConnection::createSinkClientCallback( StreamId streamId, RocketServerConnection& connection) { auto callback = std::make_unique<RocketSinkClientCallback>(streamId, connection); auto& callbackRef = *callback; streams_.emplace(streamId, std::move(callback)); return callbackRef; } void RocketServerConnection::flushWrites( std::unique_ptr<folly::IOBuf> writes, WriteBatchContext&& context) { DestructorGuard dg(this); DVLOG(10) << fmt::format("write: {} B", writes->computeChainDataLength()); inflightWritesQueue_.push_back(std::move(context)); socket_->writeChain(this, std::move(writes)); } void RocketServerConnection::send( std::unique_ptr<folly::IOBuf> data, apache::thrift::MessageChannel::SendCallbackPtr cb) { evb_.dcheckIsInEventBaseThread(); if (state_ != ConnectionState::ALIVE && state_ != ConnectionState::DRAINING) { return; } writeBatcher_.enqueueWrite(std::move(data), std::move(cb)); } RocketServerConnection::~RocketServerConnection() { DCHECK(inflightRequests_ == 0); DCHECK(inflightWritesQueue_.empty()); DCHECK(inflightSinkFinalResponses_ == 0); DCHECK(writeBatcher_.empty()); DCHECK(activePausedHandlers_ == 0); if (rawSocket_) { rawSocket_->setBufferCallback(nullptr); } // Subtle: Close the socket, which will fail all outstanding writes and // unsubscribe the read callback, but do not destroy the object itself, since // other member variables of RocketServerConnection may be borrowing the // object. socket_->closeNow(); // reclaim any memory in use by pending writes if (egressBufferSize_) { egressMemoryTracker_.decrement(egressBufferSize_); DVLOG(10) << "buffered: 0 (-" << egressBufferSize_ << ") B"; egressBufferSize_ = 0; } } void RocketServerConnection::closeIfNeeded() { if (state_ == ConnectionState::DRAINING && inflightRequests_ == 0 && inflightSinkFinalResponses_ == 0) { DestructorGuard dg(this); socketDrainer_.activate(); if (drainCompleteCode_) { ServerPushMetadata serverMeta; serverMeta.drainCompletePush_ref() .ensure() .drainCompleteCode_ref() .from_optional(drainCompleteCode_); sendMetadataPush(packCompact(std::move(serverMeta))); // Send CONNECTION_ERROR error in case client doesn't support // DrainCompletePush sendError(StreamId{0}, RocketException(ErrorCode::CONNECTION_ERROR)); } state_ = ConnectionState::CLOSING; frameHandler_->connectionClosing(); closeIfNeeded(); return; } if (state_ != ConnectionState::CLOSING) { return; } if (!socket_->good()) { socketDrainer_.drainComplete(); } if (isBusy() || !socketDrainer_.isDrainComplete()) { return; } DestructorGuard dg(this); // Update state_ early, as subsequent lines may call recursively into // closeIfNeeded(). Such recursive calls should be no-ops. state_ = ConnectionState::CLOSED; if (auto* manager = getConnectionManager()) { manager->removeConnection(this); } while (!streams_.empty()) { auto callback = std::move(streams_.begin()->second); streams_.erase(streams_.begin()); // Calling application callback may trigger rehashing. folly::variant_match( callback, [](const std::unique_ptr<RocketStreamClientCallback>& callback) { callback->onStreamCancel(); }, [](const std::unique_ptr<RocketSinkClientCallback>& callback) { bool state = callback->onSinkError(TApplicationException( TApplicationException::TApplicationExceptionType::INTERRUPTION)); DCHECK(state) << "onSinkError called after sink complete!"; }); requestComplete(); } writeBatcher_.drain(); destroy(); } void RocketServerConnection::handleFrame(std::unique_ptr<folly::IOBuf> frame) { DestructorGuard dg(this); if (state_ != ConnectionState::ALIVE && state_ != ConnectionState::DRAINING) { return; } frameHandler_->onBeforeHandleFrame(); folly::io::Cursor cursor(frame.get()); const auto streamId = readStreamId(cursor); FrameType frameType; Flags flags; std::tie(frameType, flags) = readFrameTypeAndFlags(cursor); if (UNLIKELY(!setupFrameReceived_)) { if (frameType != FrameType::SETUP) { return close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID_SETUP, "First frame must be SETUP frame")); } setupFrameReceived_ = true; } else { if (UNLIKELY(frameType == FrameType::SETUP)) { return close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID_SETUP, "More than one SETUP frame received")); } } switch (frameType) { case FrameType::SETUP: { return frameHandler_->handleSetupFrame( SetupFrame(std::move(frame)), *this); } case FrameType::REQUEST_RESPONSE: { return handleRequestFrame( RequestResponseFrame(streamId, flags, cursor, std::move(frame))); } case FrameType::REQUEST_FNF: { return handleRequestFrame( RequestFnfFrame(streamId, flags, cursor, std::move(frame))); } case FrameType::REQUEST_STREAM: { return handleRequestFrame( RequestStreamFrame(streamId, flags, cursor, std::move(frame))); } case FrameType::REQUEST_CHANNEL: { return handleRequestFrame( RequestChannelFrame(streamId, flags, cursor, std::move(frame))); } case FrameType::KEEPALIVE: { if (streamId == StreamId{0}) { KeepAliveFrame keepAliveFrame{std::move(frame)}; if (keepAliveFrame.hasRespondFlag()) { // Echo back data without 'respond' flag send(KeepAliveFrame{Flags(), std::move(keepAliveFrame).data()} .serialize()); } } else { close(folly::make_exception_wrapper<RocketException>( ErrorCode::CONNECTION_ERROR, fmt::format( "Received keepalive frame with non-zero stream ID {}", static_cast<uint32_t>(streamId)))); } return; } // for the rest of frame types, how to deal with them depends on whether // they are part of a streaming or sink (or channel) default: { auto iter = streams_.find(streamId); if (UNLIKELY(iter == streams_.end())) { handleUntrackedFrame( std::move(frame), streamId, frameType, flags, std::move(cursor)); } else { folly::variant_match( iter->second, [&](const std::unique_ptr<RocketStreamClientCallback>& clientCallback) { handleStreamFrame( std::move(frame), streamId, frameType, flags, std::move(cursor), *clientCallback); }, [&](const std::unique_ptr<RocketSinkClientCallback>& clientCallback) { handleSinkFrame( std::move(frame), streamId, frameType, flags, std::move(cursor), *clientCallback); }); } } } } void RocketServerConnection::handleUntrackedFrame( std::unique_ptr<folly::IOBuf> frame, StreamId streamId, FrameType frameType, Flags flags, folly::io::Cursor cursor) { switch (frameType) { case FrameType::PAYLOAD: { auto it = partialRequestFrames_.find(streamId); if (it == partialRequestFrames_.end()) { return; } PayloadFrame payloadFrame(streamId, flags, cursor, std::move(frame)); folly::variant_match(it->second, [&](auto& requestFrame) { const bool hasFollows = payloadFrame.hasFollows(); requestFrame.payload().append(std::move(payloadFrame.payload())); if (!hasFollows) { RocketServerFrameContext(*this, streamId) .onFullFrame(std::move(requestFrame)); partialRequestFrames_.erase(streamId); } }); return; } case FrameType::CANCEL: FOLLY_FALLTHROUGH; case FrameType::REQUEST_N: FOLLY_FALLTHROUGH; case FrameType::ERROR: return; case FrameType::EXT: { ExtFrame extFrame(streamId, flags, cursor, std::move(frame)); switch (extFrame.extFrameType()) { case ExtFrameType::INTERACTION_TERMINATE: { DCHECK_LT(getVersion(), 7); InteractionTerminate term; unpackCompact(term, extFrame.payload().buffer()); frameHandler_->terminateInteraction(term.get_interactionId()); return; } default: if (!extFrame.hasIgnore()) { close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, fmt::format( "Received unhandleable ext frame type ({}) without ignore flag", static_cast<uint32_t>(extFrame.extFrameType())))); } return; } } case FrameType::METADATA_PUSH: { MetadataPushFrame metadataFrame(std::move(frame)); ClientPushMetadata clientMeta; try { unpackCompact(clientMeta, metadataFrame.metadata()); } catch (...) { close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, "Failed to deserialize metadata push frame")); return; } switch (clientMeta.getType()) { case ClientPushMetadata::interactionTerminate: { DCHECK_GE(getVersion(), 7); frameHandler_->terminateInteraction( *clientMeta.interactionTerminate_ref()->interactionId_ref()); break; } case ClientPushMetadata::streamHeadersPush: { DCHECK_GE(getVersion(), 7); StreamId sid( clientMeta.streamHeadersPush_ref()->streamId_ref().value_or(0)); auto it = streams_.find(sid); if (it != streams_.end()) { folly::variant_match( it->second, [&](const std::unique_ptr<RocketStreamClientCallback>& clientCallback) { std::ignore = clientCallback->getStreamServerCallback().onSinkHeaders( HeadersPayload(clientMeta.streamHeadersPush_ref() ->headersPayloadContent_ref() .value_or({}))); }, [&](const std::unique_ptr<RocketSinkClientCallback>&) {}); } break; } default: break; } return; } default: close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, fmt::format( "Received unhandleable frame type ({})", static_cast<uint8_t>(frameType)))); } } void RocketServerConnection::handleStreamFrame( std::unique_ptr<folly::IOBuf> frame, StreamId streamId, FrameType frameType, Flags flags, folly::io::Cursor cursor, RocketStreamClientCallback& clientCallback) { if (!clientCallback.serverCallbackReady()) { switch (frameType) { case FrameType::CANCEL: { return clientCallback.earlyCancelled(); } default: return close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, fmt::format( "Received unexpected early frame, stream id ({}) type ({})", static_cast<uint32_t>(streamId), static_cast<uint8_t>(frameType)))); } } switch (frameType) { case FrameType::REQUEST_N: { RequestNFrame requestNFrame(streamId, flags, cursor); clientCallback.request(requestNFrame.requestN()); return; } case FrameType::CANCEL: { clientCallback.onStreamCancel(); freeStream(streamId, true); return; } case FrameType::EXT: { ExtFrame extFrame(streamId, flags, cursor, std::move(frame)); switch (extFrame.extFrameType()) { case ExtFrameType::HEADERS_PUSH: { DCHECK_LT(getVersion(), 7); auto& serverCallback = clientCallback.getStreamServerCallback(); auto headers = unpack<HeadersPayload>(std::move(extFrame.payload())); if (headers.hasException()) { serverCallback.onStreamCancel(); freeStream(streamId, true); return; } std::ignore = serverCallback.onSinkHeaders(std::move(*headers)); return; } case ExtFrameType::ALIGNED_PAGE: case ExtFrameType::INTERACTION_TERMINATE: case ExtFrameType::CUSTOM_ALLOC: case ExtFrameType::UNKNOWN: if (extFrame.hasIgnore()) { return; } close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, fmt::format( "Received unhandleable EXT frame type ({}) for stream (id {})", static_cast<uint32_t>(extFrame.extFrameType()), static_cast<uint32_t>(streamId)))); return; } } default: close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, fmt::format( "Received unhandleable frame type ({}) for stream (id {})", static_cast<uint8_t>(frameType), static_cast<uint32_t>(streamId)))); } } void RocketServerConnection::handleSinkFrame( std::unique_ptr<folly::IOBuf> frame, StreamId streamId, FrameType frameType, Flags flags, folly::io::Cursor cursor, RocketSinkClientCallback& clientCallback) { if (!clientCallback.serverCallbackReady()) { switch (frameType) { case FrameType::ERROR: { ErrorFrame errorFrame{std::move(frame)}; if (errorFrame.errorCode() == ErrorCode::CANCELED) { return clientCallback.earlyCancelled(); } } default: return close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, fmt::format( "Received unexpected early frame, stream id ({}) type ({})", static_cast<uint32_t>(streamId), static_cast<uint8_t>(frameType)))); } } auto handleSinkPayload = [&](PayloadFrame&& payloadFrame) { const bool next = payloadFrame.hasNext(); const bool complete = payloadFrame.hasComplete(); if (auto fullPayload = bufferOrGetFullPayload(std::move(payloadFrame))) { bool notViolateContract = true; if (next) { auto streamPayload = rocket::unpack<StreamPayload>(std::move(*fullPayload)); if (streamPayload.hasException()) { notViolateContract = clientCallback.onSinkError(std::move(streamPayload.exception())); if (notViolateContract) { freeStream(streamId, true); } } else { auto payloadMetadataRef = streamPayload->metadata.payloadMetadata_ref(); if (payloadMetadataRef && payloadMetadataRef->getType() == PayloadMetadata::exceptionMetadata) { notViolateContract = clientCallback.onSinkError( apache::thrift::detail::EncodedStreamError( std::move(streamPayload.value()))); if (notViolateContract) { freeStream(streamId, true); } } else { notViolateContract = clientCallback.onSinkNext(std::move(*streamPayload)); } } } if (complete) { // it is possible final repsonse(error) sent from serverCallback, // serverCallback may be already destoryed. if (streams_.find(streamId) != streams_.end()) { notViolateContract = clientCallback.onSinkComplete(); } } if (!notViolateContract) { close(folly::make_exception_wrapper<transport::TTransportException>( transport::TTransportException::TTransportExceptionType:: STREAMING_CONTRACT_VIOLATION, "receiving sink payload frame after sink completion")); } } }; switch (frameType) { case FrameType::PAYLOAD: { PayloadFrame payloadFrame(streamId, flags, cursor, std::move(frame)); handleSinkPayload(std::move(payloadFrame)); } break; case FrameType::ERROR: { ErrorFrame errorFrame{std::move(frame)}; auto ew = [&] { if (errorFrame.errorCode() == ErrorCode::CANCELED) { return folly::make_exception_wrapper<TApplicationException>( TApplicationException::TApplicationExceptionType::INTERRUPTION); } else { return folly::make_exception_wrapper<RocketException>( errorFrame.errorCode(), std::move(errorFrame.payload()).data()); } }(); bool notViolateContract = clientCallback.onSinkError(std::move(ew)); if (notViolateContract) { freeStream(streamId, true); } else { close(folly::make_exception_wrapper<transport::TTransportException>( transport::TTransportException::TTransportExceptionType:: STREAMING_CONTRACT_VIOLATION, "receiving sink error frame after sink completion")); } } break; case FrameType::EXT: { ExtFrame extFrame(streamId, flags, cursor, std::move(frame)); auto extFrameType = extFrame.extFrameType(); if (extFrameType == ExtFrameType::ALIGNED_PAGE || extFrameType == ExtFrameType::CUSTOM_ALLOC) { PayloadFrame payloadFrame( streamId, std::move(extFrame.payload()), flags); handleSinkPayload(std::move(payloadFrame)); break; } } default: close(folly::make_exception_wrapper<RocketException>( ErrorCode::INVALID, fmt::format( "Received unhandleable frame type ({}) for sink (id {})", static_cast<uint8_t>(frameType), static_cast<uint32_t>(streamId)))); } } void RocketServerConnection::close(folly::exception_wrapper ew) { if (state_ == ConnectionState::CLOSING || state_ == ConnectionState::CLOSED) { closeIfNeeded(); return; } DestructorGuard dg(this); socketDrainer_.activate(); if (!ew.with_exception<RocketException>([this](RocketException rex) { sendError(StreamId{0}, std::move(rex)); })) { auto rex = ew ? RocketException(ErrorCode::CONNECTION_ERROR, ew.what()) : RocketException(ErrorCode::CONNECTION_CLOSE, "Closing connection"); sendError(StreamId{0}, std::move(rex)); } state_ = ConnectionState::CLOSING; frameHandler_->connectionClosing(); closeIfNeeded(); } void RocketServerConnection::timeoutExpired() noexcept { DestructorGuard dg(this); if (!isBusy()) { closeWhenIdle(); } } bool RocketServerConnection::isBusy() const { return inflightRequests_ != 0 || !inflightWritesQueue_.empty() || inflightSinkFinalResponses_ != 0 || !writeBatcher_.empty() || activePausedHandlers_ != 0; } // On graceful shutdown, ConnectionManager will first fire the // notifyPendingShutdown() callback for each connection. Then, after the drain // period has elapsed, closeWhenIdle() will be called for each connection. // Note that ConnectionManager waits for a connection to become un-busy before // calling closeWhenIdle(). void RocketServerConnection::notifyPendingShutdown() { startDrain({}); } void RocketServerConnection::startDrain( std::optional<DrainCompleteCode> drainCompleteCode) { if (state_ != ConnectionState::ALIVE) { return; } state_ = ConnectionState::DRAINING; drainCompleteCode_ = drainCompleteCode; sendError(StreamId{0}, RocketException(ErrorCode::CONNECTION_CLOSE)); closeIfNeeded(); } void RocketServerConnection::dropConnection(const std::string& /* errorMsg */) { socketDrainer_.drainComplete(); close(folly::make_exception_wrapper<transport::TTransportException>( transport::TTransportException::TTransportExceptionType::INTERRUPTED, "Dropping connection")); } void RocketServerConnection::closeWhenIdle() { socketDrainer_.drainComplete(); close(folly::make_exception_wrapper<transport::TTransportException>( transport::TTransportException::TTransportExceptionType::INTERRUPTED, "Closing due to imminent shutdown")); } void RocketServerConnection::writeSuccess() noexcept { DestructorGuard dg(this); DCHECK(!inflightWritesQueue_.empty()); auto& context = inflightWritesQueue_.front(); for (auto processingCompleteCount = context.requestCompleteCount; processingCompleteCount > 0; --processingCompleteCount) { frameHandler_->requestComplete(); } for (auto& cb : context.sendCallbacks) { cb.release()->messageSent(); } inflightWritesQueue_.pop_front(); if (onWriteQuiescence_ && writeBatcher_.empty() && inflightWritesQueue_.empty()) { onWriteQuiescence_(ReadPausableHandle(this)); return; } closeIfNeeded(); } void RocketServerConnection::writeErr( size_t /* bytesWritten */, const folly::AsyncSocketException& ex) noexcept { DestructorGuard dg(this); DCHECK(!inflightWritesQueue_.empty()); auto& context = inflightWritesQueue_.front(); for (auto processingCompleteCount = context.requestCompleteCount; processingCompleteCount > 0; --processingCompleteCount) { frameHandler_->requestComplete(); } auto ew = folly::make_exception_wrapper<transport::TTransportException>(ex); for (auto& cb : context.sendCallbacks) { cb.release()->messageSendError(folly::copy(ew)); } inflightWritesQueue_.pop_front(); close(std::move(ew)); } void RocketServerConnection::onEgressBuffered() { const auto buffered = rawSocket_->getAllocatedBytesBuffered(); const auto oldBuffered = egressBufferSize_; egressBufferSize_ = buffered; // track egress memory consumption, drop connection if necessary if (buffered < oldBuffered) { const auto delta = oldBuffered - buffered; egressMemoryTracker_.decrement(delta); DVLOG(10) << fmt::format("buffered: {} (-{}) B", buffered, delta); } else { const auto delta = buffered - oldBuffered; const auto exceeds = !egressMemoryTracker_.increment(delta); DVLOG(10) << fmt::format("buffered: {} (+{}) B", buffered, delta); if (exceeds && rawSocket_->good()) { DestructorGuard dg(this); FB_LOG_EVERY_MS(ERROR, 1000) << fmt::format( "Dropping connection: exceeded egress memory limit ({})", getPeerAddress().describe()); rawSocket_->closeNow(); // triggers writeErr() events now return; } } // pause streams if buffer size reached backpressure threshold if (!egressBufferBackpressureThreshold_) { return; } else if (buffered > egressBufferBackpressureThreshold_ && !streamsPaused_) { pauseStreams(); } else if (streamsPaused_ && buffered < egressBufferRecoverySize_) { resumeStreams(); } } void RocketServerConnection::onEgressBufferCleared() { if (egressBufferSize_) { egressMemoryTracker_.decrement(egressBufferSize_); DVLOG(10) << "buffered: 0 (-" << egressBufferSize_ << ") B"; egressBufferSize_ = 0; } if (UNLIKELY(streamsPaused_)) { resumeStreams(); } } void RocketServerConnection::scheduleStreamTimeout( folly::HHWheelTimer::Callback* timeoutCallback) { if (streamStarvationTimeout_ != std::chrono::milliseconds::zero()) { evb_.timer().scheduleTimeout(timeoutCallback, streamStarvationTimeout_); } } void RocketServerConnection::scheduleSinkTimeout( folly::HHWheelTimer::Callback* timeoutCallback, std::chrono::milliseconds timeout) { if (timeout != std::chrono::milliseconds::zero()) { evb_.timer().scheduleTimeout(timeoutCallback, timeout); } } folly::Optional<Payload> RocketServerConnection::bufferOrGetFullPayload( PayloadFrame&& payloadFrame) { folly::Optional<Payload> fullPayload; const auto streamId = payloadFrame.streamId(); const bool hasFollows = payloadFrame.hasFollows(); const auto it = bufferedFragments_.find(streamId); if (hasFollows) { if (it != bufferedFragments_.end()) { auto& firstFragments = it->second; firstFragments.append(std::move(payloadFrame.payload())); } else { bufferedFragments_.emplace(streamId, std::move(payloadFrame.payload())); } } else { if (it != bufferedFragments_.end()) { auto firstFragments = std::move(it->second); bufferedFragments_.erase(it); firstFragments.append(std::move(payloadFrame.payload())); fullPayload = std::move(firstFragments); } else { fullPayload = std::move(payloadFrame.payload()); } } return fullPayload; } void RocketServerConnection::sendPayload( StreamId streamId, Payload&& payload, Flags flags, apache::thrift::MessageChannel::SendCallbackPtr cb) { send( PayloadFrame(streamId, std::move(payload), flags).serialize(), std::move(cb)); } void RocketServerConnection::sendError( StreamId streamId, RocketException&& rex, apache::thrift::MessageChannel::SendCallbackPtr cb) { send(ErrorFrame(streamId, std::move(rex)).serialize(), std::move(cb)); } void RocketServerConnection::sendRequestN(StreamId streamId, int32_t n) { send(RequestNFrame(streamId, n).serialize()); } void RocketServerConnection::sendCancel(StreamId streamId) { send(CancelFrame(streamId).serialize()); } void RocketServerConnection::sendExt( StreamId streamId, Payload&& payload, Flags flags, ExtFrameType extFrameType) { send(ExtFrame(streamId, std::move(payload), flags, extFrameType).serialize()); } void RocketServerConnection::sendMetadataPush( std::unique_ptr<folly::IOBuf> metadata) { send(MetadataPushFrame::makeFromMetadata(std::move(metadata)).serialize()); } void RocketServerConnection::freeStream( StreamId streamId, bool markRequestComplete) { DestructorGuard dg(this); bufferedFragments_.erase(streamId); DCHECK(streams_.find(streamId) != streams_.end()); streams_.erase(streamId); if (markRequestComplete) { requestComplete(); } } void RocketServerConnection::applyDscpAndMarkToSocket( const RequestSetupMetadata& setupMetadata) { constexpr int32_t kMaxDscpValue = (1 << 6) - 1; if (!socket_) { return; } try { folly::SocketAddress addr; socket_->getAddress(&addr); if (addr.getFamily() != AF_INET6 && addr.getFamily() != AF_INET) { return; } if (auto* sock = socket_->getUnderlyingTransport<folly::AsyncSocket>()) { const auto fd = sock->getNetworkSocket(); if (auto dscp = setupMetadata.dscpToReflect_ref()) { if (auto context = frameHandler_->getCpp2ConnContext()) { THRIFT_CONNECTION_EVENT(rocket.dscp).log(*context, [&] { return folly::dynamic::object("rocket_dscp", *dscp); }); } if (*dscp >= 0 && *dscp <= kMaxDscpValue) { const folly::SocketOptionKey kIpv4TosKey = {IPPROTO_IP, IP_TOS}; const folly::SocketOptionKey kIpv6TosKey = { IPPROTO_IPV6, IPV6_TCLASS}; auto& dscpKey = addr.getIPAddress().isV4() ? kIpv4TosKey : kIpv6TosKey; dscpKey.apply(fd, *dscp << 2); } } #if defined(SO_MARK) if (auto mark = setupMetadata.markToReflect_ref()) { const folly::SocketOptionKey kSoMarkKey = {SOL_SOCKET, SO_MARK}; kSoMarkKey.apply(fd, *mark); } #endif } } catch (const std::exception& ex) { FB_LOG_EVERY_MS(WARNING, 60 * 1000) << "Failed to apply DSCP to socket: " << folly::exceptionStr(ex); } catch (...) { FB_LOG_EVERY_MS(WARNING, 60 * 1000) << "Failed to apply DSCP to socket: " << folly::exceptionStr(std::current_exception()); } } RocketServerConnection::ReadResumableHandle::ReadResumableHandle( RocketServerConnection* connection) : connection_(connection) {} RocketServerConnection::ReadResumableHandle::~ReadResumableHandle() { if (connection_ != nullptr) { std::move(*this).resume(); } } RocketServerConnection::ReadResumableHandle::ReadResumableHandle( ReadResumableHandle&& handle) noexcept : connection_(std::exchange(handle.connection_, nullptr)) {} RocketServerConnection::ReadPausableHandle::ReadPausableHandle( RocketServerConnection* connection) : connection_(connection) { ++connection_->activePausedHandlers_; } void RocketServerConnection::ReadResumableHandle::resume() && { DCHECK(connection_ != nullptr) << "resume() has been called on this handle"; --connection_->activePausedHandlers_; if (connection_->state_ == ConnectionState::ALIVE || connection_->state_ == ConnectionState::DRAINING) { connection_->socket_->setReadCB(&connection_->parser_); } connection_->closeIfNeeded(); connection_ = nullptr; } RocketServerConnection::ReadPausableHandle::~ReadPausableHandle() { if (connection_ != nullptr) { --connection_->activePausedHandlers_; connection_->closeIfNeeded(); } } RocketServerConnection::ReadPausableHandle::ReadPausableHandle( ReadPausableHandle&& handle) noexcept : connection_(std::exchange(handle.connection_, nullptr)) {} RocketServerConnection::ReadResumableHandle RocketServerConnection::ReadPausableHandle::pause() && { DCHECK(connection_ != nullptr) << "pause() has been called on this handle"; connection_->socket_->setReadCB(nullptr); return ReadResumableHandle(std::exchange(connection_, nullptr)); } void RocketServerConnection::pauseStreams() { DCHECK(!streamsPaused_); streamsPaused_ = true; for (auto it = streams_.begin(); it != streams_.end(); it++) { folly::variant_match( it->second, [](const std::unique_ptr<RocketStreamClientCallback>& stream) { stream->pauseStream(); }, [](const auto&) {}); } } void RocketServerConnection::resumeStreams() { DCHECK(streamsPaused_); streamsPaused_ = false; for (auto it = streams_.begin(); it != streams_.end(); it++) { folly::variant_match( it->second, [](const std::unique_ptr<RocketStreamClientCallback>& stream) { stream->resumeStream(); }, [](const auto&) {}); } } } // namespace rocket } // namespace thrift } // namespace apache