cpp-channel/cpp/HsChannel.h (209 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. #pragma once #include <folly/io/IOBuf.h> #include <thrift/lib/cpp/concurrency/Thread.h> #include <thrift/lib/cpp/protocol/TProtocolTypes.h> #include <thrift/lib/cpp2/async/RequestChannel.h> #include <HsFFI.h> enum Status { SEND_ERROR, SEND_SUCCESS, RECV_ERROR, RECV_SUCCESS }; struct FinishedRequest { const uint8_t* buffer; size_t len; Status status; }; using InnerChannel = std::unique_ptr< apache::thrift::RequestChannel, folly::DelayedDestruction::Destructor>; class HsCallback : public apache::thrift::RequestClientCallback { public: explicit HsCallback( std::shared_ptr<InnerChannel> client, int cap, HsStablePtr send_mvar, HsStablePtr recv_mvar, FinishedRequest* send_result, FinishedRequest* recv_result) : client_(std::move(client)), cap_(cap), send_mvar_(send_mvar), recv_mvar_(recv_mvar), send_result_(send_result), recv_result_(recv_result) {} bool isInlineSafe() const override { // our callbacks do memcpy/malloc/hs_try_putmvar (which is nonblocking and // very quick). // this should be inline safe return true; } // Note [onResponse leak] // // The memory containing the result is transferred from C++ to // Haskell in onResponseError / onResponse. hs_try_putmvar() wakes // up the Haskell thread running CppChannel.hsc:sendCollector // or CppChannel.hsc:recvCollector, which takes ownership of the // memory in a ForeignPtr. // // This is carefully designed to not leak even if the thread making // the original Thrift request is // interrupted. sendCollector/recvCollector are running in a // separate thread which will run and take ownership of the memory // even if the original thread that made the request has gone away. // // However, if the program exits after onResponse/onResponseError // but before the Haskell thread running sendCollector/recvCollector // runs, this memory may be detected as a leak by leak-checkers such // as ASAN. It's not really a leak, just an artifact of exiting at // the wrong time. // void onResponseError(folly::exception_wrapper ew) noexcept override { auto ex = ew.what(); size_t len = ex.length(); auto buf = std::unique_ptr<uint8_t, decltype(free)*>{ reinterpret_cast<uint8_t*>(malloc(len * sizeof(uint8_t))), free}; // If you get a memory leak here, see Note [onResponse leak]. std::memcpy(buf.get(), ex.data(), len); bool sendError = false; ew.with_exception( [&](apache::thrift::transport::TTransportException const& tex) { sendError = tex.getType() == apache::thrift::transport::TTransportException::NOT_OPEN; }); if (sendError || !recv_result_) { send_result_->status = SEND_ERROR; send_result_->buffer = buf.release(); send_result_->len = len; hs_try_putmvar(cap_, send_mvar_); } else { requestSentHelper(); recv_result_->status = RECV_ERROR; recv_result_->buffer = buf.release(); recv_result_->len = len; hs_try_putmvar(cap_, recv_mvar_); } delete this; } void onResponse( apache::thrift::ClientReceiveState&& state) noexcept override { SCOPE_EXIT { delete this; }; requestSentHelper(); if (!recv_result_) { return; } if (state.isException()) { auto ex = state.exception().what(); size_t len = ex.length(); auto buf = std::unique_ptr<uint8_t, decltype(free)*>{ reinterpret_cast<uint8_t*>(malloc(len * sizeof(uint8_t))), free}; // If you get a memory leak here, see Note [onResponse leak]. std::memcpy(buf.get(), ex.data(), len); recv_result_->status = RECV_ERROR; recv_result_->buffer = buf.release(); recv_result_->len = len; hs_try_putmvar(cap_, recv_mvar_); } else { auto ioBuf = apache::thrift::LegacySerializedResponse( state.protocolId(), 0, state.messageType(), methodName_, state.extractSerializedResponse()) .buffer; size_t len = ioBuf->computeChainDataLength(); auto msg = std::unique_ptr<uint8_t, decltype(free)*>{ reinterpret_cast<uint8_t*>(malloc(len * sizeof(uint8_t))), free}; auto pos = msg.get(); for (auto r : *ioBuf) { std::memcpy(pos, r.data(), r.size()); pos += r.size(); } recv_result_->status = RECV_SUCCESS; recv_result_->buffer = msg.release(); recv_result_->len = len; hs_try_putmvar(cap_, recv_mvar_); } } void setMethodName(std::string name) { methodName_ = std::move(name); } std::string const& getMethodName() const { return methodName_; } void setContextStack( std::unique_ptr<apache::thrift::ContextStack>&& contextStack) { contextStack_ = std::move(contextStack); } private: void requestSentHelper() { send_result_->status = SEND_SUCCESS; hs_try_putmvar(cap_, send_mvar_); } std::shared_ptr<InnerChannel> client_; // see Note [channel lifetime] int cap_; HsStablePtr send_mvar_; HsStablePtr recv_mvar_; FinishedRequest* send_result_; FinishedRequest* recv_result_; std::string methodName_; // Note that the contextStack_ *need* to be declared after // methodName_. ContextStack contains a pointer reference // to the method-name, and needs to access this pointer in its // destructor. Fields are destructed in reverse order, and we need // methodName_ to have a larger lifetime than the ContextStack // object. std::unique_ptr<apache::thrift::ContextStack> contextStack_; }; using CallbackPtr = std::unique_ptr< HsCallback, apache::thrift::RequestClientCallback::RequestClientCallbackDeleter>; /* Note [channel lifetime] * * The ChannelWrapper implementation keeps the InnerChannel alive * until all the outstanding HsCallbacks have been called. This is to * support use cases that need to receive the responses to requests * outside of the scope of the channel creation * (e.g. withHeaderChannel). An example is the Haxl datasource for * Thrift services, which does not have a way to scope * withHeaderChannel over the lifetime of the requests. Without this * feature, the channel is closed on exit from the scope of * withHeaderChannel, and the outstanding requests will fail. * * The alternative to keeping the InnerChannel alive here would be to * make the client do its own reference counting, which is hard and * error-prone. * * For a test case see thrift/cpp-channel/tests/LifetimeTest.hs */ class ChannelWrapper : public apache::thrift::TClientBase { public: explicit ChannelWrapper(InnerChannel client) : client_(std::make_shared<InnerChannel>(std::move(client))) {} ~ChannelWrapper() { auto evb = client_->get()->getEventBase(); // Only run the destructor once all callbacks are done auto destroyClient = [client = std::move(client_)]() mutable {}; // Move the unique_ptr into the lambda so that it gets destructed in the // eventbase thread if (evb) { evb->runInEventBaseThread(std::move(destroyClient)); } } apache::thrift::protocol::PROTOCOL_TYPES getProtocolType(uint8_t buf) { // Check hex value to see which protocol is being used return buf == 0x82 ? apache::thrift::protocol::PROTOCOL_TYPES::T_COMPACT_PROTOCOL : apache::thrift::protocol::PROTOCOL_TYPES::T_BINARY_PROTOCOL; } void sendRequest( uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, HsStablePtr recv_mvar, FinishedRequest* send_result, FinishedRequest* recv_result, apache::thrift::RpcOptions&& rpcOpts); void sendOnewayRequest( uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, FinishedRequest* send_result, apache::thrift::RpcOptions&& rpcOpts); InnerChannel* getInnerRequestChannel() { return client_.get(); } private: enum RequestDirection { WITH_RESPONSE = 1, NO_RESPONSE = 2, }; void sendRequestImpl( RequestDirection direction, apache::thrift::protocol::PROTOCOL_TYPES protocolId, CallbackPtr&& callback, std::unique_ptr<folly::IOBuf>&& message, apache::thrift::RpcOptions&& rpcOptions); template <typename F> void runOnClientEvbIfAvailable(F&& f) { if (auto evb = client_->get()->getEventBase()) { evb->add(std::forward<F>(f)); } else { f(); } } std::shared_ptr<InnerChannel> client_; }; extern "C" { std::shared_ptr<ChannelWrapper>* newWrapper(InnerChannel* channel) noexcept; void deleteWrapper(std::shared_ptr<ChannelWrapper>* channel) noexcept; void sendReq( std::shared_ptr<ChannelWrapper>* client, uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, HsStablePtr recv_mvar, FinishedRequest* send_result, FinishedRequest* recv_result, uint8_t* rpcOptionsPtr, size_t rpcOptionsLen) noexcept; void sendOnewayReq( std::shared_ptr<ChannelWrapper>* client, uint8_t* buf, size_t len, int capability, HsStablePtr send_mvar, FinishedRequest* send_result, uint8_t* rpcOptionsPtr, size_t rpcOptionsLen) noexcept; InnerChannel* getInnerRequestChannel( std::shared_ptr<ChannelWrapper>* client) noexcept; }