cpp-channel/cpp/HsChannel.cpp (188 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. #include <cpp/HsChannel.h> #include <if/gen-cpp2/RpcOptions_types.h> #include <thrift/lib/cpp/ContextStack.h> #include <thrift/lib/cpp2/protocol/Serializer.h> using namespace thrift::protocol; using namespace apache::thrift::concurrency; std::shared_ptr<ChannelWrapper>* newWrapper(InnerChannel* channel) noexcept { return new std::shared_ptr<ChannelWrapper>( std::make_shared<ChannelWrapper>(std::move(*channel))); } void deleteWrapper(std::shared_ptr<ChannelWrapper>* channel) noexcept { delete channel; } void ChannelWrapper::sendRequest( uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, HsStablePtr recv_mvar, FinishedRequest* send_result, FinishedRequest* recv_result, apache::thrift::RpcOptions&& rpcOpts) { auto msg = folly::IOBuf::wrapBuffer(buf, len); auto cob = CallbackPtr(new HsCallback( client_, capability, send_mvar, recv_mvar, send_result, recv_result)); sendRequestImpl( ChannelWrapper::RequestDirection::WITH_RESPONSE, getProtocolType(buf[0]), std::move(cob), std::move(msg), std::move(rpcOpts)); } void ChannelWrapper::sendOnewayRequest( uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, FinishedRequest* send_result, apache::thrift::RpcOptions&& rpcOpts) { auto msg = folly::IOBuf::wrapBuffer(buf, len); auto cob = CallbackPtr(new HsCallback( client_, capability, send_mvar, nullptr, send_result, nullptr)); sendRequestImpl( ChannelWrapper::RequestDirection::NO_RESPONSE, getProtocolType(buf[0]), std::move(cob), std::move(msg), std::move(rpcOpts)); } void ChannelWrapper::sendRequestImpl( ChannelWrapper::RequestDirection requestDirection, apache::thrift::protocol::PROTOCOL_TYPES protocolId, CallbackPtr&& callback, std::unique_ptr<folly::IOBuf>&& message, apache::thrift::RpcOptions&& rpcOptions) { auto header = std::make_shared<apache::thrift::transport::THeader>(0); header->setProtocolId(protocolId); header->setHeaders(rpcOptions.releaseWriteHeaders()); auto envelopeAndRequest = apache::thrift::EnvelopeUtil::stripRequestEnvelope(std::move(message)); if (!envelopeAndRequest.has_value()) { callback.release()->onResponseError( folly::make_exception_wrapper< apache::thrift::transport::TTransportException>( apache::thrift::transport::TTransportException::CORRUPTED_DATA, "Unexpected problem stripping envelope")); return; } auto envelope = std::move(envelopeAndRequest->first); callback->setMethodName(envelope.methodName); // Create a new context-stack for the request, which will be used to trigger // the appropriate thrift middleware to run on the request in itself (e.g. // ContextProp). // // Note that we need to be very careful about the lifetime of the object // and everything it does reference, as this can cause issues with memory // leaks. // // This is why we're directly referencing shared-pointers as well as // preserving the lifetime of the stack, and the method-name, on the callback // object itself. auto contextStack = apache::thrift::ContextStack::createWithClientContext( handlers_, "" /* service name */, callback->getMethodName().c_str(), *header); if (contextStack) { contextStack->preWrite(); } auto request = apache::thrift::SerializedRequest(std::move(envelopeAndRequest->second)); if (contextStack) { apache::thrift::SerializedMessage serializedMessage; serializedMessage.protocolType = envelope.protocolId; serializedMessage.buffer = request.buffer.get(); serializedMessage.methodName = envelope.methodName; contextStack->onWriteData(serializedMessage); contextStack->postWrite( folly::to_narrow(request.buffer->computeChainDataLength())); contextStack->resetClientRequestContextHeader(); } // Transfer ownership of the context-stack to the callback in order to // preserve lifetime throughout the request callback->setContextStack(std::move(contextStack)); runOnClientEvbIfAvailable([client = client_, requestDirection = requestDirection, rpcOptions = std::move(rpcOptions), request = std::move(request), header = std::move(header), envelope = std::move(envelope), callback = std::move(callback)]() mutable { switch (requestDirection) { case ChannelWrapper::RequestDirection::WITH_RESPONSE: client->get()->sendRequestResponse( std::move(rpcOptions), envelope.methodName, std::move(request), std::move(header), std::move(callback)); break; case ChannelWrapper::RequestDirection::NO_RESPONSE: client->get()->sendRequestNoResponse( std::move(rpcOptions), envelope.methodName, std::move(request), std::move(header), std::move(callback)); break; } }); } static_assert((int)Priority::HighImportant == (int)HIGH_IMPORTANT); static_assert((int)Priority::High == (int)HIGH); static_assert((int)Priority::Important == (int)IMPORTANT); static_assert((int)Priority::NormalPriority == (int)NORMAL); static_assert((int)Priority::BestEffort == (int)BEST_EFFORT); apache::thrift::RpcOptions getRpcOptions( uint8_t* rpcOptionsPtr, size_t rpcOptionsLen) noexcept { apache::thrift::RpcOptions rpcOpts; auto tRpcOpts = apache::thrift::BinarySerializer::deserialize< thrift::protocol::RpcOptions>( folly::ByteRange(rpcOptionsPtr, rpcOptionsLen)); rpcOpts.setTimeout(std::chrono::milliseconds(*tRpcOpts.timeout_ref())); auto priority = tRpcOpts.get_priority() == nullptr ? apache::thrift::RpcOptions::PRIORITY::NORMAL : static_cast<apache::thrift::RpcOptions::PRIORITY>( tRpcOpts.priority_ref().value_unchecked()); rpcOpts.setPriority(priority); rpcOpts.setChunkTimeout( std::chrono::milliseconds(*tRpcOpts.chunkTimeout_ref())); rpcOpts.setQueueTimeout( std::chrono::milliseconds(*tRpcOpts.queueTimeout_ref())); if (tRpcOpts.get_headers() != nullptr) { for (auto const& header : tRpcOpts.headers_ref().value_unchecked()) { rpcOpts.setWriteHeader(header.first, header.second); } } return rpcOpts; } void sendReq( std::shared_ptr<ChannelWrapper>* client, uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, HsStablePtr recv_mvar, FinishedRequest* send_result, FinishedRequest* recv_result, uint8_t* rpcOptionsPtr, size_t rpcOptionsLen) noexcept { apache::thrift::RpcOptions rpcOpts = getRpcOptions(rpcOptionsPtr, rpcOptionsLen); (*client)->sendRequest( buf, len, capability, send_mvar, recv_mvar, send_result, recv_result, std::move(rpcOpts)); } void sendOnewayReq( std::shared_ptr<ChannelWrapper>* client, uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, FinishedRequest* send_result, uint8_t* rpcOptionsPtr, size_t rpcOptionsLen) noexcept { apache::thrift::RpcOptions rpcOpts = getRpcOptions(rpcOptionsPtr, rpcOptionsLen); (*client)->sendOnewayRequest( buf, len, capability, send_mvar, send_result, std::move(rpcOpts)); } InnerChannel* getInnerRequestChannel( std::shared_ptr<ChannelWrapper>* client) noexcept { return (*client)->getInnerRequestChannel(); }