cpp/acs_agent_client.cc (436 lines of code) (raw):

#include "cpp/acs_agent_client.h"

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <future>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>

#include "proto/agent_communication.grpc.pb.h"
#include "absl/functional/any_invocable.h"
#include "absl/functional/bind_front.h"
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h"
#include "absl/random/distributions.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "cpp/acs_agent_client_reactor.h"
#include "cpp/acs_agent_helper.h"
#include "grpcpp/support/status.h"

namespace agent_communication {

// Alias of the protobuf message types used in the ACS Agent Communication
// service in a .cc file.
using Response =
    ::google::cloud::agentcommunication::v1::StreamAgentMessagesResponse;
using Request =
    ::google::cloud::agentcommunication::v1::StreamAgentMessagesRequest;
using MessageBody = ::google::cloud::agentcommunication::v1::MessageBody;
using AcsStub =
    ::google::cloud::agentcommunication::v1::AgentCommunication::Stub;

namespace {

// Maximum number of attempts made by AddRequest() to send one request.
constexpr int kMaxSendAttempts = 5;

// Maximum number of attempts made by AddRequestAndWaitForResponse() to hand a
// request over to the reactor.
constexpr int kMaxEnqueueAttempts = 5;

// Translates the result of a wait on a response future that did NOT become
// ready into the status reported to the caller, logging the reason.
//
// Returns DeadlineExceeded for a timeout and Internal for a deferred future.
// Any other value maps to OK, which matches the behaviour of the code this
// helper was factored out of.
absl::Status StatusForUnansweredRequest(std::future_status wait_result,
                                        const std::string& message_id) {
  if (wait_result == std::future_status::timeout) {
    ABSL_VLOG(1) << "timeout of waiting for response: " << message_id;
    return absl::DeadlineExceededError(absl::StrFormat(
        "Timeout waiting for promise to be set for message with id: %s.",
        message_id));
  }
  if (wait_result == std::future_status::deferred) {
    ABSL_LOG(WARNING)
        << "This should never happen: get a deferred status from the future "
           "when waiting for response for message with id: "
        << message_id;
    return absl::InternalError(absl::StrFormat(
        "Future is deferred for message with id: %s. This should never happen.",
        message_id));
  }
  return absl::OkStatus();
}

}  // namespace

// Creates a client from a channel id: derives the connection id, builds a stub
// for the resulting endpoint and delegates to the stub-based overload (without
// custom stub / connection-id generators).
absl::StatusOr<std::unique_ptr<AcsAgentClient>> AcsAgentClient::Create(
    bool endpoint_regional, std::string channel_id,
    absl::AnyInvocable<void(
        google::cloud::agentcommunication::v1::StreamAgentMessagesResponse)>
        read_callback,
    std::chrono::milliseconds max_wait_time_for_ack) {
  // Generate the connection id.
  absl::StatusOr<AgentConnectionId> agent_connection_id =
      GenerateAgentConnectionId(std::move(channel_id), endpoint_regional);
  if (!agent_connection_id.ok()) {
    return agent_connection_id.status();
  }

  // Create the stub.
  std::unique_ptr<AcsStub> stub =
      AcsAgentClientReactor::CreateStub(agent_connection_id->endpoint);
  return Create(std::move(stub), *std::move(agent_connection_id),
                std::move(read_callback), nullptr, nullptr,
                max_wait_time_for_ack);
}

// Creates a client on top of an existing stub and connection id.
//
// Starts the response-reading thread, creates the reactor, registers the
// connection via Init() and, only if that succeeded, starts the thread that
// restarts the client. Returns the Init() status on failure.
absl::StatusOr<std::unique_ptr<AcsAgentClient>> AcsAgentClient::Create(
    std::unique_ptr<
        google::cloud::agentcommunication::v1::AgentCommunication::Stub>
        stub,
    AgentConnectionId agent_connection_id,
    absl::AnyInvocable<void(
        google::cloud::agentcommunication::v1::StreamAgentMessagesResponse)>
        read_callback,
    absl::AnyInvocable<std::unique_ptr<AcsStub>()> stub_generator,
    absl::AnyInvocable<absl::StatusOr<AgentConnectionId>()>
        connection_id_generator,
    std::chrono::milliseconds max_wait_time_for_ack) {
  // Create the client.
  std::unique_ptr<AcsAgentClient> client = absl::WrapUnique(new AcsAgentClient(
      agent_connection_id, std::move(read_callback), std::move(stub_generator),
      std::move(connection_id_generator), max_wait_time_for_ack));

  // Start the read message thread.
  std::thread read_message_body_thread(
      absl::bind_front(&AcsAgentClient::ClientReadMessage, client.get()));
  client->read_response_thread_ = std::move(read_message_body_thread);

  // Initialize the client.
  {
    absl::MutexLock lock(&client->reactor_mtx_);
    client->reactor_ = std::make_unique<AcsAgentClientReactor>(
        std::move(stub),
        absl::bind_front(&AcsAgentClient::ReactorReadCallback, client.get()),
        std::move(agent_connection_id));
    absl::Status init_status = client->Init();
    if (!init_status.ok()) {
      // NOTE(review): `client` is destroyed on this path while
      // read_response_thread_ is still running; this relies on the
      // AcsAgentClient destructor joining that thread (presumably through
      // Shutdown()) - confirm against the header.
      return init_status;
    }
  }

  // Start the restart client thread after the Init() was successful.
  std::thread restart_client_thread(
      absl::bind_front(&AcsAgentClient::RestartClient, client.get()));
  client->restart_client_thread_ = std::move(restart_client_thread);
  return client;
}

// Returns the messages-per-minute quota reported by the reactor, or
// FailedPrecondition if there is no reactor or the stream is not ready.
absl::StatusOr<uint64_t> AcsAgentClient::GetMessagePerMinuteQuota() {
  absl::MutexLock lock(&reactor_mtx_);
  if (reactor_ == nullptr || stream_state_ != ClientState::kReady) {
    return absl::FailedPreconditionError("stream not initialized.");
  }
  return reactor_->GetMessagesPerMinuteQuota();
}

// Returns the bytes-per-minute quota reported by the reactor, or
// FailedPrecondition if there is no reactor or the stream is not ready.
absl::StatusOr<uint64_t> AcsAgentClient::GetBytesPerMinuteQuota() {
  absl::MutexLock lock(&reactor_mtx_);
  if (reactor_ == nullptr || stream_state_ != ClientState::kReady) {
    return absl::FailedPreconditionError("stream not initialized.");
  }
  return reactor_->GetBytesPerMinuteQuota();
}

// Sends `request`, retrying up to kMaxSendAttempts times.
//
// A fresh message id is written into `request` before every attempt. The
// status of the last attempt is returned:
//   - OK: returned immediately.
//   - AlreadyExists (request never reached the reactor): retry after 1s.
//   - ResourceExhausted: retry after a linearly growing delay capped at 1s.
//   - DeadlineExceeded or any other error: stop retrying.
absl::Status AcsAgentClient::AddRequest(Request& request) {
  absl::Status latest_send_request_status = absl::OkStatus();
  // TODO: Make the retry parameters configurable.
  for (int attempt = 0; attempt < kMaxSendAttempts; ++attempt) {
    // Generate a new message id for each attempt.
    {
      // NOTE(review): the lock presumably guards the random generator used by
      // CreateMessageUuid() - confirm against the header annotations.
      absl::MutexLock lock(&reactor_mtx_);
      request.set_message_id(CreateMessageUuid());
    }
    latest_send_request_status = AddRequestAndWaitForResponse(request);
    if (latest_send_request_status.ok()) {
      return absl::OkStatus();
    }

    const absl::StatusCode code = latest_send_request_status.code();
    if (code == absl::StatusCode::kAlreadyExists) {
      // Retry to send the request because the request was not even buffered in
      // the reactor.
      ABSL_VLOG(2) << absl::StrFormat(
          "Failed to add message with id: %s to reactor as the ongoing write "
          "takes. Retrying with a sleep delay.",
          request.message_id());
      absl::SleepFor(absl::Seconds(1));
      continue;
    }
    if (code == absl::StatusCode::kDeadlineExceeded) {
      // Give up retrying to send the request because the wait for the response
      // timed out. This generally means the server is not responding, and the
      // client should re-create the connection.
      ABSL_VLOG(1) << absl::StrFormat(
          "Successfully added message with id: %s to reactor, but timed out "
          "waiting for response from server. Please re-create the connection.",
          request.message_id());
      break;
    }
    if (code == absl::StatusCode::kResourceExhausted) {
      // Retry to send the request because the resource exhausted with backoff.
      ABSL_VLOG(1) << absl::StrFormat(
          "Successfully added message with id: %s to reactor, but get a "
          "resource exhausted error. Retrying with a sleep delay.",
          request.message_id());
      const int delay_millis = std::min(250 * (attempt + 1), 1000);
      absl::SleepFor(absl::Milliseconds(delay_millis));
      continue;
    }
    // Stop retrying if the send failed due to any other error.
    break;
  }
  return latest_send_request_status;
}

// Wraps `message_body` into a request and sends it through AddRequest().
absl::Status AcsAgentClient::SendMessage(MessageBody message_body) {
  Request request;
  *request.mutable_message_body() = std::move(message_body);
  return AddRequest(request);
}

bool AcsAgentClient::IsDead() {
  // If the stream_state_ is kStreamFailedToInitialize or kShutdown, the client
  // is dead, no point to retry sending messages. If the stream_state_ is
  // kStreamClosed, the restart_client_thread_ will restart the client soon, the
  // caller can retry sending messages later with a backoff mechanism. If the
  // stream_state_ is kStreamNotInitialized, the client is not dead yet, just
  // waiting for the successful registration.
  absl::MutexLock lock(&reactor_mtx_);
  return stream_state_ == ClientState::kStreamFailedToInitialize ||
         stream_state_ == ClientState::kShutdown;
}

// Registers the connection and, on success, marks the stream as kReady.
// Both callers in this file (Create() and RestartClient()) invoke it while
// holding reactor_mtx_.
absl::Status AcsAgentClient::Init() {
  // Register the connection. The registration request should only be sent once.
  std::unique_ptr<Request> registration_request =
      agent_communication::MakeRequestWithRegistration(
          CreateMessageUuid(), connection_id_.channel_id,
          connection_id_.resource_id);
  if (absl::Status status = RegisterConnection(*registration_request);
      !status.ok()) {
    return status;
  }
  stream_state_ = ClientState::kReady;
  return absl::OkStatus();
}

// Hands `request` to the reactor and blocks until the server acknowledges it
// or max_wait_time_for_ack_ elapses.
//
// Returns:
//   - the status carried by the server response when one arrives in time;
//   - Internal / FailedPrecondition when the stream state does not allow this
//     kind of request, or when there is no reactor;
//   - AlreadyExists when the reactor refused the request on every attempt;
//   - DeadlineExceeded when no response arrived in time.
// On every path that does not consume the response, the promise registered
// for the message id is removed again so the map does not grow.
absl::Status AcsAgentClient::AddRequestAndWaitForResponse(
    const Request& request) {
  // Queue up the message in the reactor and create the promise-future pair to
  // wait for the response from the server.
  const std::string& message_id = request.message_id();
  std::promise<absl::Status> response_promise;
  std::future<absl::Status> response_future = response_promise.get_future();
  {
    absl::MutexLock lock(&request_delivery_status_mtx_);
    attempted_requests_responses_sub_.emplace(message_id,
                                              std::move(response_promise));
  }

  bool added_to_reactor = false;
  absl::Status precondition_status = absl::OkStatus();
  // TODO: Make the retry parameters configurable.
  for (int attempt = 0; attempt < kMaxEnqueueAttempts; ++attempt) {
    {
      absl::MutexLock lock(&reactor_mtx_);
      if (request.has_register_connection() &&
          (stream_state_ != ClientState::kStreamNotInitialized &&
           stream_state_ != ClientState::kStreamTemporarilyDown)) {
        precondition_status = absl::InternalError(
            "The stream is not in the correct state to accept new registration "
            "request.");
      } else if (request.has_message_body() &&
                 stream_state_ != ClientState::kReady) {
        precondition_status = absl::FailedPreconditionError(
            "The stream is not ready to accept new MessageBody.");
      } else if (reactor_ == nullptr) {
        // Shutdown() resets reactor_; never dereference a null reactor.
        precondition_status =
            absl::FailedPreconditionError("stream not initialized.");
      } else if (reactor_->AddRequest(request)) {
        added_to_reactor = true;
      }
    }
    if (added_to_reactor || !precondition_status.ok()) {
      break;
    }
    // Exponential backoff (100ms, 200ms, ...) capped at 2s, slept without
    // holding reactor_mtx_.
    const int delay_millis = std::min(100 * (1 << attempt), 2000);
    absl::SleepFor(absl::Milliseconds(delay_millis));
  }

  if (!precondition_status.ok()) {
    // The request was rejected before reaching the reactor: drop the promise
    // registered above, otherwise its map entry would never be removed.
    absl::MutexLock lock(&request_delivery_status_mtx_);
    SetValueAndRemovePromise(message_id, absl::OkStatus());
    return precondition_status;
  }

  if (!added_to_reactor) {
    // Set a dummy status value and clean up the promise if we fail to add the
    // request to the reactor.
    absl::MutexLock lock(&request_delivery_status_mtx_);
    SetValueAndRemovePromise(message_id, absl::OkStatus());
    ABSL_VLOG(1) << absl::StrFormat(
        "Failed to add message with id: %s to reactor as the ongoing write "
        "takes too long.",
        message_id);
    return absl::AlreadyExistsError(
        "Failed to add message to reactor because the ongoing write takes too "
        "long.");
  }

  // Now that we have added the request to the reactor, wait for the response
  // from the server.
  const std::future_status wait_result =
      response_future.wait_for(max_wait_time_for_ack_);
  if (wait_result == std::future_status::ready) {
    return response_future.get();
  }
  absl::Status received_status =
      StatusForUnansweredRequest(wait_result, message_id);

  // Set a dummy status value and clean up the promise if we don't receive the
  // response from the server.
  absl::MutexLock lock(&request_delivery_status_mtx_);
  SetValueAndRemovePromise(message_id, absl::OkStatus());
  return received_status;
}

// Wake-up condition of ClientReadMessage(): there is a queued response or the
// reader has been told to shut down.
bool AcsAgentClient::ShouldWakeUpClientReadMessage() {
  return !msg_responses_.empty() ||
         client_read_state_ == ClientState::kShutdown;
}

// Body of read_response_thread_: drains msg_responses_ one response at a time
// until client_read_state_ becomes kShutdown.
void AcsAgentClient::ClientReadMessage() {
  while (true) {
    // This thread will be woken up by the ReactorReadCallback() when reactor
    // calls OnReadDone() or woken up by Shutdown().
    // Within every iteration, if we don't shutdown, we will pop out 1 message,
    // exit the critical section, and then process the message by calling
    // AckOnSuccessfulDelivery() and read_callback_. In this way, we can release
    // the lock response_read_mtx_ and avoid blocking the OnReadDone() call of
    // the reactor.
    response_read_mtx_.LockWhen(
        absl::Condition(this, &AcsAgentClient::ShouldWakeUpClientReadMessage));
    if (client_read_state_ == ClientState::kShutdown) {
      response_read_mtx_.Unlock();
      return;
    }
    if (msg_responses_.empty()) {
      response_read_mtx_.Unlock();
      continue;
    }
    Response response = std::move(msg_responses_.front());
    msg_responses_.pop();
    response_read_mtx_.Unlock();

    // Exit the critical section and process the message.
    if (response.has_message_response()) {
      AckOnSuccessfulDelivery(response);
    }
    read_callback_(std::move(response));
  }
}

// Callback handed to the reactor. Depending on `status` it either ignores the
// event (closed by client), flags the stream as temporarily down so that
// RestartClient() restarts it (closed by server), or queues `response` for
// ClientReadMessage().
void AcsAgentClient::ReactorReadCallback(
    Response response, AcsAgentClientReactor::RpcStatus status) {
  if (status == AcsAgentClientReactor::RpcStatus::kRpcClosedByClient) {
    ABSL_VLOG(1) << "RPC is closed by client, don't restart the stream.";
    return;
  }
  if (status == AcsAgentClientReactor::RpcStatus::kRpcClosedByServer) {
    ABSL_VLOG(1) << "RPC is closed by server, restarting the stream.";
    // Wakes up RestartClient() to restart the stream.
    absl::MutexLock lock(&reactor_mtx_);
    if (stream_state_ != ClientState::kShutdown) {
      // If the stream is being shutdown by client, ie. stream_state_ is
      // kShutdown, we should not restart the stream.
      stream_state_ = ClientState::kStreamTemporarilyDown;
    }
    return;
  }

  // Wake up ClientReadMessage().
  absl::MutexLock lock(&response_read_mtx_);
  msg_responses_.push(std::move(response));
  // Log the response that was just queued (the back of the queue), not the
  // oldest pending one.
  ABSL_VLOG(2) << "Producer called with response: "
               << absl::StrCat(msg_responses_.back());
}

// Sends the registration `request` through the reactor and waits up to
// max_wait_time_for_ack_ for the server's answer. Returns the status carried
// by the answer, Internal if the reactor refused the request, or
// DeadlineExceeded if no answer arrived in time.
absl::Status AcsAgentClient::RegisterConnection(const Request& request) {
  // Add request message to the reactor and create the promise-future pair to
  // wait for the response from the server.
  const std::string& message_id = request.message_id();
  std::promise<absl::Status> response_promise;
  std::future<absl::Status> response_future = response_promise.get_future();
  {
    absl::MutexLock lock(&request_delivery_status_mtx_);
    attempted_requests_responses_sub_.emplace(message_id,
                                              std::move(response_promise));
  }
  const bool added_to_reactor = reactor_->AddRequest(request);
  if (!added_to_reactor) {
    absl::MutexLock lock(&request_delivery_status_mtx_);
    SetValueAndRemovePromise(message_id, absl::OkStatus());
    return absl::InternalError(
        "Failed to add registration request to reactor, because the existing "
        "write buffer is full. This should never happen, because the "
        "registration request should be the first request sent to the server.");
  }

  // Now that we have added the request to the reactor, wait for the response
  // from the server.
  // Note that during the wait here, we hold the reactor_mtx_ lock. This is
  // intentional to keep the client from sending any other requests. The
  // downside is that if reactor calls OnReadDone(ok=false), which indicates the
  // failure of register connection, ReactorReadCallback() will not be able to
  // acquire the reactor_mtx_ lock until the wait here is done. This is fine
  // for now, because this function will still return a failed status, and wait
  // for the caller of this class to retry.
  const std::future_status wait_result =
      response_future.wait_for(max_wait_time_for_ack_);
  if (wait_result == std::future_status::ready) {
    return response_future.get();
  }
  absl::Status received_status =
      StatusForUnansweredRequest(wait_result, message_id);

  // Set a dummy status value and clean up the promise if we don't receive the
  // response from the server.
  absl::MutexLock lock(&request_delivery_status_mtx_);
  SetValueAndRemovePromise(message_id, absl::OkStatus());
  return received_status;
}

// Body of restart_client_thread_: sleeps until the stream is flagged as
// temporarily down or shut down. On shutdown the thread exits; otherwise it
// waits for the old reactor to finish, builds a new stub and reactor and
// re-registers. Any failure marks the stream kStreamFailedToInitialize and
// ends the thread.
void AcsAgentClient::RestartClient() {
  while (true) {
    reactor_mtx_.LockWhen(absl::Condition(
        +[](ClientState* stream_state) {
          return *stream_state == ClientState::kStreamTemporarilyDown ||
                 *stream_state == ClientState::kShutdown;
        },
        &stream_state_));
    // Terminate the thread if the client is being shutdown.
    if (stream_state_ == ClientState::kShutdown) {
      reactor_mtx_.Unlock();
      return;
    }

    // Wait for the reactor to be terminated, capture the status, and then
    // restart the reactor.
    // TODO: need to determine if we want to retry based on the status, and add
    // retry logic with backoff mechanism.
    if (reactor_ != nullptr) {
      grpc::Status status = reactor_->Await();
      ABSL_VLOG(1) << absl::StrFormat(
          "RestartReactor thread trying to restart the stream with previous "
          "termination status code: %d and message: %s and details: %s",
          status.error_code(), status.error_message(), status.error_details());
    }

    std::unique_ptr<AcsStub> stub = GenerateConnectionIdAndStub();
    if (stub == nullptr) {
      stream_state_ = ClientState::kStreamFailedToInitialize;
      reactor_mtx_.Unlock();
      return;
    }
    // std::make_unique never yields nullptr, so no null check is needed here.
    reactor_ = std::make_unique<AcsAgentClientReactor>(
        std::move(stub),
        absl::bind_front(&AcsAgentClient::ReactorReadCallback, this),
        connection_id_);

    // Initialize the client.
    absl::Status init_status = Init();
    if (!init_status.ok()) {
      stream_state_ = ClientState::kStreamFailedToInitialize;
      reactor_mtx_.Unlock();
      return;
    }
    reactor_mtx_.Unlock();
  }
}

// Refreshes connection_id_ (through connection_id_generator_ when one was
// supplied, otherwise through GenerateAgentConnectionId()) and returns a stub
// for it (through stub_generator_ when one was supplied, otherwise through
// AcsAgentClientReactor::CreateStub()). Returns nullptr if no connection id
// could be obtained.
std::unique_ptr<AcsStub> AcsAgentClient::GenerateConnectionIdAndStub() {
  absl::StatusOr<AgentConnectionId> new_connection_id =
      connection_id_generator_ != nullptr
          ? connection_id_generator_()
          : GenerateAgentConnectionId(connection_id_.channel_id,
                                      connection_id_.regional);
  if (!new_connection_id.ok()) {
    ABSL_LOG(WARNING) << "Failed to get connection id from agent connection "
                      << "name: " << connection_id_.channel_id;
    return nullptr;
  }
  connection_id_ = *std::move(new_connection_id);
  if (stub_generator_ != nullptr) {
    return stub_generator_();
  }
  return AcsAgentClientReactor::CreateStub(connection_id_.endpoint);
}

// Resolves the promise registered for `response.message_id()` with the status
// carried by the server response. Logs a warning and does nothing when no
// promise is registered for that id.
void AcsAgentClient::AckOnSuccessfulDelivery(const Response& response) {
  absl::MutexLock lock(&request_delivery_status_mtx_);
  const std::string& message_id = response.message_id();
  if (attempted_requests_responses_sub_.find(message_id) ==
      attempted_requests_responses_sub_.end()) {
    ABSL_LOG(WARNING) << absl::StrFormat(
        "Failed to find the promise for message with id: %s, but we got the "
        "response from the server with content: %s",
        message_id, response.DebugString());
    return;
  }
  // Convert the google::rpc::status proto to absl::status object and set the
  // promise.
  const absl::StatusCode code = static_cast<absl::StatusCode>(
      response.message_response().status().code());
  SetValueAndRemovePromise(
      message_id,
      absl::Status(code, response.message_response().status().message()));
}

// Stops both worker threads (reader first, then the restart thread) and then
// releases the reactor. Threads that are not joinable are skipped, so calling
// this more than once is harmless.
void AcsAgentClient::Shutdown() {
  if (read_response_thread_.joinable()) {
    {
      // Wakes up ClientReadMessage() to shut it down.
      absl::MutexLock lock(&response_read_mtx_);
      client_read_state_ = ClientState::kShutdown;
    }
    read_response_thread_.join();
  }

  // Shutdown the restart_client_thread_ before the RPC is terminated.
  // Otherwise, the restart_client_thread_ may try to restart the RPC.
  if (restart_client_thread_.joinable()) {
    {
      absl::MutexLock lock(&reactor_mtx_);
      stream_state_ = ClientState::kShutdown;
    }
    restart_client_thread_.join();
  }

  {
    absl::MutexLock lock(&reactor_mtx_);
    reactor_.reset();
  }
}

// Builds a message id of the form "<random int64>-<unix time in microseconds>".
std::string AcsAgentClient::CreateMessageUuid() {
  int64_t random =
      absl::Uniform<int64_t>(gen_, 0, std::numeric_limits<int64_t>::max());
  return absl::StrCat(random, "-", absl::ToUnixMicros(absl::Now()));
}

// Fulfils the promise registered under `message_id` with `status` and removes
// it from the map; does nothing when no such promise exists. Every caller in
// this file holds request_delivery_status_mtx_.
void AcsAgentClient::SetValueAndRemovePromise(const std::string& message_id,
                                              absl::Status status) {
  auto it = attempted_requests_responses_sub_.find(message_id);
  if (it == attempted_requests_responses_sub_.end()) {
    return;
  }
  it->second.set_value(std::move(status));
  attempted_requests_responses_sub_.erase(it);
}

}  // namespace agent_communication