cpp/acs_agent_client_reactor.cc (219 lines of code) (raw):

#include "cpp/acs_agent_client_reactor.h"

#include <unistd.h>

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>

#include "proto/agent_communication.grpc.pb.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "cpp/acs_agent_helper.h"
#include "grpc/grpc.h"
#include "grpcpp/security/credentials.h"
#include "grpcpp/support/channel_arguments.h"
#include "grpcpp/support/status.h"
#include "grpcpp/support/string_ref.h"

namespace agent_communication {

// Aliases of the protobuf message types used in the ACS Agent Communication
// service in a .cc file.
using Response =
    ::google::cloud::agentcommunication::v1::StreamAgentMessagesResponse;
using Request =
    ::google::cloud::agentcommunication::v1::StreamAgentMessagesRequest;

// Alias of the stub type used in the ACS Agent Communication service in a .cc
// file.
using AcsStub =
    ::google::cloud::agentcommunication::v1::AgentCommunication::Stub;

namespace {

// Value passed as GRPC_ARG_KEEPALIVE_TIME_MS on the channel built by
// CreateStub().
constexpr int kKeepaliveTimeMs = 60 * 1000;  // 60 seconds
// Value passed as GRPC_ARG_KEEPALIVE_TIMEOUT_MS on the channel built by
// CreateStub().
constexpr int kKeepaliveTimeoutMs = 10 * 1000;  // 10 seconds

}  // namespace

// Binds this reactor to a new StreamAgentMessages bidirectional stream on
// `stub`, posts the first read into `response_` and starts the call. Each
// response read from the stream (and the final close) is reported through
// `read_callback`.
AcsAgentClientReactor::AcsAgentClientReactor(
    std::unique_ptr<AcsStub> stub,
    absl::AnyInvocable<void(Response, RpcStatus)> read_callback)
    : stub_(std::move(stub)), read_callback_(std::move(read_callback)) {
  stub_->async()->StreamAgentMessages(&context_, this);
  StartRead(&response_);
  StartCall();
}

// Same as the constructor above, but first attaches the connection's token,
// resource id and channel id from `agent_connection_id` as request metadata.
AcsAgentClientReactor::AcsAgentClientReactor(
    std::unique_ptr<AcsStub> stub,
    absl::AnyInvocable<void(Response, RpcStatus)> read_callback,
    const AgentConnectionId& agent_connection_id)
    : stub_(std::move(stub)), read_callback_(std::move(read_callback)) {
  context_.AddMetadata("authentication",
                       "Bearer " + agent_connection_id.token);
  context_.AddMetadata("agent-communication-resource-id",
                       agent_connection_id.resource_id);
  context_.AddMetadata("agent-communication-channel-id",
                       agent_connection_id.channel_id);
  stub_->async()->StreamAgentMessages(&context_, this);
  StartRead(&response_);
  StartCall();
}

// Builds an AgentCommunication stub over an SSL channel to `endpoint`, with
// keepalive pings enabled.
std::unique_ptr<AcsStub> AcsAgentClientReactor::CreateStub(
    const std::string& endpoint) {
  grpc::SslCredentialsOptions options;
  grpc::ChannelArguments channel_args;
  // Keepalive settings
  channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, kKeepaliveTimeMs);
  channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, kKeepaliveTimeoutMs);
  return google::cloud::agentcommunication::v1::AgentCommunication::NewStub(
      grpc::CreateCustomChannel(endpoint, grpc::SslCredentials(options),
                                channel_args));
}

// Returns the messages-per-minute quota parsed from the server's initial
// metadata, or whatever value the member held before that metadata arrived.
absl::StatusOr<uint64_t> AcsAgentClientReactor::GetMessagesPerMinuteQuota() {
  absl::MutexLock lock(&status_mtx_);
  return messages_per_minute_quota_;
}

// Returns the bytes-per-minute quota parsed from the server's initial
// metadata, or whatever value the member held before that metadata arrived.
absl::StatusOr<uint64_t> AcsAgentClientReactor::GetBytesPerMinuteQuota() {
  absl::MutexLock lock(&status_mtx_);
  return bytes_per_minute_quota_;
}

// Requests cancellation of the RPC.
// Returns false (and does nothing) if OnDone() has already run. Otherwise it
// records that the client initiated the close, so OnReadDone() can report
// kRpcClosedByClient, calls TryCancel() and returns true.
bool AcsAgentClientReactor::Cancel() {
  absl::MutexLock lock(&status_mtx_);
  if (rpc_done_) {
    ABSL_VLOG(1)
        << "The RPC has already been terminated when attempting to cancel.";
    return false;
  }
  rpc_cancelled_by_client_ = true;
  context_.TryCancel();
  return true;
}

// Cancels the stream (if still running) and blocks until OnDone() has
// recorded the final status, so the RPC is finished before members are torn
// down.
AcsAgentClientReactor::~AcsAgentClientReactor() {
  Cancel();
  const grpc::Status status = Await();
  ABSL_VLOG(1) << absl::StrFormat(
      "AcsAgentClientReactor is destroyed with termination status code: %d "
      "and message: %s and details: %s",
      status.error_code(), status.error_message(), status.error_details());
}

// Completion of the write started by NextWrite().
// On failure it returns without clearing `writing_`, so no further writes are
// started on this stream. On success it releases the request that was just
// written (an ack from `ack_buffer_`, or `msg_request_`) and starts the next
// pending write, if any.
void AcsAgentClientReactor::OnWriteDone(bool ok) {
  if (!ok) {
    ABSL_VLOG(1) << "OnWriteDone not ok";
    return;
  }
  absl::MutexLock lock(&request_mtx_);
  writing_ = false;
  if (request_->has_message_response()) {
    // Pop the queue of ack_buffer_ if the last write was an ack.
    ABSL_VLOG(1) << "Successfully write on the stream with ack with id: "
                 << request_->message_id();
    ack_buffer_.pop();
  } else {
    // Clear up the msg_request_ if the last write was a message.
    ABSL_VLOG(1) << "Successfully write on the stream with message with id: "
                 << request_->message_id();
    msg_request_ = nullptr;
  }
  NextWrite();
}

// Queues an ack for `message_id`. If no write is in flight, the ack is written
// immediately.
void AcsAgentClientReactor::Ack(std::string message_id) {
  absl::MutexLock lock(&request_mtx_);
  std::unique_ptr<Request> request = MakeAck(std::move(message_id));
  ack_buffer_.push(std::move(request));
  if (!writing_) {
    NextWrite();
  }
}

// Completion of a read into `response_`.
// On failure, the stream is closed. It then calls `read_callback_` with an
// empty Response and either kRpcClosedByClient (Cancel() was called) or
// kRpcClosedByServer. On success, it acks responses that carry a message body,
// hands the response to `read_callback_` with kRpcOk and posts the next read.
void AcsAgentClientReactor::OnReadDone(bool ok) {
  if (!ok) {
    ABSL_VLOG(1) << "OnReadDone not ok";
    // Take a snapshot of the flag under the lock, but invoke the user
    // callback without holding status_mtx_. The callback may call back into
    // this reactor (e.g. Cancel() or Get*Quota()), and those methods acquire
    // status_mtx_. absl::Mutex is not reentrant, so calling them while the
    // lock is held would deadlock.
    bool cancelled_by_client = false;
    {
      absl::MutexLock lock(&status_mtx_);
      cancelled_by_client = rpc_cancelled_by_client_;
    }
    read_callback_(Response(), cancelled_by_client
                                   ? RpcStatus::kRpcClosedByClient
                                   : RpcStatus::kRpcClosedByServer);
    return;
  }
  if (response_.has_message_body()) {
    ABSL_VLOG(1) << "Client Ack on message with id: "
                 << response_.message_id();
    Ack(response_.message_id());
  }
  read_callback_(std::move(response_), RpcStatus::kRpcOk);
  StartRead(&response_);
}

// Reads the per-minute message and byte quotas from the server's initial
// metadata and stores the results (values or NotFound errors) for the
// Get*Quota() accessors.
void AcsAgentClientReactor::OnReadInitialMetadataDone(bool ok) {
  if (!ok) {
    ABSL_LOG(WARNING) << "OnReadInitialMetadataDone not ok";
    return;
  }
  const std::multimap<grpc::string_ref, grpc::string_ref>& metadata =
      context_.GetServerInitialMetadata();
  // Parse outside the critical section; only publishing the results needs
  // status_mtx_.
  auto messages_quota = GetIntValueFromInitialMetadata<uint64_t>(
      metadata, kMessagesPerMinuteQuotaKey);
  auto bytes_quota = GetIntValueFromInitialMetadata<uint64_t>(
      metadata, kBytesPerMinuteQuotaKey);
  absl::MutexLock lock(&status_mtx_);
  messages_per_minute_quota_ = std::move(messages_quota);
  bytes_per_minute_quota_ = std::move(bytes_quota);
}

// Returns the first value stored under `key` in `metadata` that parses as an
// integer of type T.
// Returns a NotFound error if `key` is absent, or if none of its values parse.
// A warning is logged for each value that fails to parse.
template <typename T>
absl::StatusOr<T> AcsAgentClientReactor::GetIntValueFromInitialMetadata(
    const std::multimap<grpc::string_ref, grpc::string_ref>& metadata,
    const std::string& key) {
  auto [range_begin, range_end] = metadata.equal_range(key);
  if (range_begin == range_end) {
    return absl::NotFoundError(absl::StrCat(
        "The key: ", key, " was not found in initial metadata from server."));
  }
  // Theoretically, there should be <=1 value for the quota keys in the initial
  // metadata of server. If there are multiple values, we will return the first
  // valid value.
  for (auto it = range_begin; it != range_end; ++it) {
    T value = 0;
    std::string value_str(it->second.data(), it->second.size());
    if (absl::SimpleAtoi(value_str, &value)) {
      return value;
    }
    ABSL_LOG(WARNING) << "key: " << key
                      << " found in initial metadata from server but its "
                         "value having the wrong format: "
                      << value_str;
  }
  return absl::NotFoundError(
      absl::StrCat("key: ", key,
                   " was found in initial metadata from server but its value "
                   "was not a valid integer."));
}

// Final reaction for the RPC: records the terminal status and sets
// `rpc_done_`, which releases any thread blocked in Await().
void AcsAgentClientReactor::OnDone(const ::grpc::Status& status) {
  absl::MutexLock lock(&status_mtx_);
  ABSL_VLOG(1) << absl::StrFormat(
      "RPC terminated with status code: %d and message: %s and details: %s",
      status.error_code(), status.error_message(), status.error_details());
  rpc_final_status_ = status;
  rpc_done_ = true;
}

// Blocks until OnDone() has run, then returns the RPC's final status.
grpc::Status AcsAgentClientReactor::Await() {
  // MutexLock with a Condition behaves like LockWhen(): it acquires
  // status_mtx_ only once rpc_done_ is true. It releases the lock on scope
  // exit, after the return value has been copied.
  absl::MutexLock lock(&status_mtx_, absl::Condition(&rpc_done_));
  return rpc_final_status_;
}

// Buffers a copy of `request` for writing and starts the write if the stream
// is idle. Only one message is buffered at a time. If the previous message
// has not finished writing, nothing is buffered and false is returned.
bool AcsAgentClientReactor::AddRequest(const Request& request) {
  absl::MutexLock lock(&request_mtx_);
  if (msg_request_ != nullptr) {
    // Return false as the last msg_request_ was not completed.
    ABSL_VLOG(1) << absl::StrFormat(
        "Failed to add request of id: %s to the buffer of reactor. The last "
        "request of id: %s is not written to the stream yet.",
        request.message_id(), msg_request_->message_id());
    return false;
  }
  // Add the new request to the buffer of reactor, as the last msg_request_
  // was completed.
  msg_request_ = std::make_unique<Request>(request);
  if (!writing_) {
    NextWrite();
  }
  return true;
}

// Starts writing the next pending request, if there is one. Queued acks are
// sent before the buffered message. Every caller in this file holds
// request_mtx_ when calling this.
void AcsAgentClientReactor::NextWrite() {
  if (msg_request_ == nullptr && ack_buffer_.empty()) {
    return;
  }
  writing_ = true;
  if (!ack_buffer_.empty()) {
    // Prioritize the send of ack over message.
    request_ = ack_buffer_.front().get();
  } else {
    request_ = msg_request_.get();
  }
  StartWrite(request_);
}

}  // namespace agent_communication