thrift/lib/cpp2/async/FutureRequest.h (317 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include <folly/CancellationToken.h> #include <folly/futures/Future.h> #include <thrift/lib/cpp2/async/RequestChannel.h> namespace apache { namespace thrift { template <typename Result> class FutureCallbackBase : public RequestCallback { public: explicit FutureCallbackBase( folly::Promise<Result>&& promise, std::shared_ptr<apache::thrift::RequestChannel> channel = nullptr) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override {} void requestError(ClientReceiveState&& state) override { CHECK(state.isException()); promise_.setException(std::move(state.exception())); } protected: folly::Promise<Result> promise_; std::shared_ptr<apache::thrift::RequestChannel> channel_; }; template <typename Result> class FutureCallback : public FutureCallbackBase<Result> { private: typedef folly::exception_wrapper (*Processor)(Result&, ClientReceiveState&); public: FutureCallback( folly::Promise<Result>&& promise, Processor processor, std::shared_ptr<apache::thrift::RequestChannel> channel = nullptr) : FutureCallbackBase<Result>(std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor_(result, state); if (ew) { this->promise_.setException(ew); } else { this->promise_.setValue(std::move(result)); } } private: Processor processor_; }; template <typename Result> class HeaderFutureCallback : public FutureCallbackBase<std::pair< Result, std::unique_ptr<apache::thrift::transport::THeader>>> { private: using HeaderResult = std::pair<Result, std::unique_ptr<apache::thrift::transport::THeader>>; typedef folly::exception_wrapper (*Processor)(Result&, ClientReceiveState&); Processor processor_; public: HeaderFutureCallback( folly::Promise<HeaderResult>&& promise, Processor processor, std::shared_ptr<apache::thrift::RequestChannel> channel = nullptr) : FutureCallbackBase<HeaderResult>( std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor_(result, state); if (ew) { this->promise_.setException(ew); } else { this->promise_.setValue( std::make_pair(std::move(result), state.extractHeader())); } } }; template <> class HeaderFutureCallback<folly::Unit> : public FutureCallbackBase<std::pair< folly::Unit, std::unique_ptr<apache::thrift::transport::THeader>>> { private: using HeaderResult = std:: pair<folly::Unit, std::unique_ptr<apache::thrift::transport::THeader>>; typedef folly::exception_wrapper (*Processor)(ClientReceiveState&); Processor processor_; public: HeaderFutureCallback( folly::Promise<HeaderResult>&& promise, Processor processor, std::shared_ptr<apache::thrift::RequestChannel> channel = nullptr) : FutureCallbackBase<HeaderResult>( std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor_(state); if (ew) { promise_.setException(ew); } else { promise_.setValue(std::make_pair(folly::Unit(), state.extractHeader())); } } }; class OneWayFutureCallback : public FutureCallbackBase<folly::Unit> { public: explicit OneWayFutureCallback( folly::Promise<folly::Unit>&& promise, std::shared_ptr<apache::thrift::RequestChannel> channel = nullptr) : FutureCallbackBase<folly::Unit>( std::move(promise), std::move(channel)) {} void requestSent() override { promise_.setValue(); } void replyReceived(ClientReceiveState&& /*state*/) override { CHECK(false); } }; template <> class FutureCallback<folly::Unit> : public FutureCallbackBase<folly::Unit> { private: typedef folly::exception_wrapper (*Processor)(ClientReceiveState&); public: FutureCallback( folly::Promise<folly::Unit>&& promise, Processor processor, std::shared_ptr<apache::thrift::RequestChannel> channel = nullptr) : FutureCallbackBase<folly::Unit>(std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor_(state); if (ew) { promise_.setException(ew); } else { promise_.setValue(); } } private: Processor processor_; }; class SemiFutureCallback : public RequestCallback { public: template <typename Result> using Processor = folly::exception_wrapper (*)(Result&, ClientReceiveState&); using ProcessorVoid = folly::exception_wrapper (*)(ClientReceiveState&); explicit SemiFutureCallback( folly::Promise<ClientReceiveState>&& promise, std::shared_ptr<apache::thrift::RequestChannel> channel) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override {} void replyReceived(ClientReceiveState&& state) override { promise_.setValue(std::move(state)); } void requestError(ClientReceiveState&& state) override { promise_.setException(std::move(state.exception())); } bool isInlineSafe() const override { return true; } protected: folly::Promise<ClientReceiveState> promise_; std::shared_ptr<apache::thrift::RequestChannel> channel_; }; class OneWaySemiFutureCallback : public RequestCallback { public: OneWaySemiFutureCallback( folly::Promise<folly::Unit>&& promise, std::shared_ptr<apache::thrift::RequestChannel> channel) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override { promise_.setValue(); } void replyReceived(ClientReceiveState&&) override { CHECK(false); } void requestError(ClientReceiveState&& state) override { promise_.setException(std::move(state.exception())); } bool isInlineSafe() const override { return true; } protected: folly::Promise<folly::Unit> promise_; std::shared_ptr<apache::thrift::RequestChannel> channel_; }; template <typename Result> std::pair<std::unique_ptr<SemiFutureCallback>, folly::SemiFuture<Result>> makeSemiFutureCallback( SemiFutureCallback::Processor<Result> processor, std::shared_ptr<apache::thrift::RequestChannel> channel) { folly::Promise<ClientReceiveState> promise; auto future = promise.getSemiFuture(); return { std::make_unique<SemiFutureCallback>( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor(result, state); if (ew) { ew.throw_exception(); } return result; })}; } inline std:: pair<std::unique_ptr<SemiFutureCallback>, folly::SemiFuture<folly::Unit>> makeSemiFutureCallback( SemiFutureCallback::ProcessorVoid processor, std::shared_ptr<apache::thrift::RequestChannel> channel) { folly::Promise<ClientReceiveState> promise; auto future = promise.getSemiFuture(); return { std::make_unique<SemiFutureCallback>( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor(state); if (ew) { ew.throw_exception(); } })}; } template <typename Result> std::pair< std::unique_ptr<SemiFutureCallback>, folly::SemiFuture< std::pair<Result, std::unique_ptr<apache::thrift::transport::THeader>>>> makeHeaderSemiFutureCallback( SemiFutureCallback::Processor<Result> processor, std::shared_ptr<apache::thrift::RequestChannel> channel) { folly::Promise<ClientReceiveState> promise; auto future = promise.getSemiFuture(); return { std::make_unique<SemiFutureCallback>( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor(result, state); if (ew) { ew.throw_exception(); } return std::make_pair(std::move(result), state.extractHeader()); })}; } inline std::pair< std::unique_ptr<SemiFutureCallback>, folly::SemiFuture<std::pair< folly::Unit, std::unique_ptr<apache::thrift::transport::THeader>>>> makeHeaderSemiFutureCallback( SemiFutureCallback::ProcessorVoid processor, std::shared_ptr<apache::thrift::RequestChannel> channel) { folly::Promise<ClientReceiveState> promise; auto future = promise.getSemiFuture(); return { std::make_unique<SemiFutureCallback>( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor(state); if (ew) { ew.throw_exception(); } return std::make_pair(folly::unit, state.extractHeader()); })}; } inline std::pair< std::unique_ptr<OneWaySemiFutureCallback>, folly::SemiFuture<folly::Unit>> makeOneWaySemiFutureCallback( std::shared_ptr<apache::thrift::RequestChannel> channel) { folly::Promise<folly::Unit> promise; auto future = promise.getSemiFuture(); return { std::make_unique<OneWaySemiFutureCallback>( std::move(promise), std::move(channel)), std::move(future)}; } template <bool oneWay> class CancellableRequestClientCallback : public RequestClientCallback { CancellableRequestClientCallback( RequestClientCallback* wrapped, std::shared_ptr<RequestChannel> channel) : callback_(wrapped), channel_(std::move(channel)) { DCHECK(wrapped->isInlineSafe()); } public: static std::unique_ptr<CancellableRequestClientCallback> create( RequestClientCallback* wrapped, std::shared_ptr<RequestChannel> channel) { return std::unique_ptr<CancellableRequestClientCallback>( new CancellableRequestClientCallback(wrapped, std::move(channel))); } static void cancel(std::unique_ptr<CancellableRequestClientCallback> cb) { cb.release()->onResponseError( folly::make_exception_wrapper<folly::OperationCancelled>()); } void onResponse(ClientReceiveState&& state) noexcept override { if (auto callback = callback_.exchange(nullptr, std::memory_order_acq_rel)) { callback->onResponse(std::move(state)); } else { delete this; } } void onResponseError(folly::exception_wrapper ew) noexcept override { if (auto callback = callback_.exchange(nullptr, std::memory_order_acq_rel)) { callback->onResponseError(std::move(ew)); } else { delete this; } } bool isInlineSafe() const override { return true; } private: std::atomic<RequestClientCallback*> callback_; std::shared_ptr<RequestChannel> channel_; }; } // namespace thrift } // namespace apache