cpp/acs_agent_helper.cc (188 lines of code) (raw):
#include "cpp/acs_agent_helper.h"
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include "google/rpc/status.pb.h"
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "cpp/jwt.h"
#include "curl/curl.h"
#include "curl/easy.h"
namespace agent_communication {
namespace {
// Internal helper functions.

// libcurl write callback (CURLOPT_WRITEFUNCTION) that appends every received
// chunk to the std::string registered through CURLOPT_WRITEDATA.
//
// The parameter list deliberately matches libcurl's `curl_write_callback`
// (char*, size_t, size_t, void*). libcurl invokes the callback through a
// pointer of exactly that type, so declaring a different parameter list (e.g.
// a typed std::string* userdata) would make each call undefined behavior even
// though it usually happens to work.
//
// Params:
//   contents: the received bytes; not NUL-terminated.
//   size: size of each element (libcurl documents this as always 1).
//   nmemb: number of elements in `contents`.
//   userdata: the std::string* passed as CURLOPT_WRITEDATA; must not be null.
// Returns: the number of bytes consumed (size * nmemb). Returning anything
//   else makes libcurl abort the transfer with CURLE_WRITE_ERROR.
size_t WriteCallback(char* contents, size_t size, size_t nmemb,
                     void* userdata) {
  const size_t total_size = size * nmemb;
  static_cast<std::string*>(userdata)->append(contents, total_size);
  return total_size;
}
}  // namespace
// File-local shorthands for the StreamAgentMessages request and response
// protos of the Agent Communication v1 API.
using Request = ::google::cloud::agentcommunication::v1::
    StreamAgentMessagesRequest;
using Response = ::google::cloud::agentcommunication::v1::
    StreamAgentMessagesResponse;
// Metadata path (appended to computeMetadata/v1/ by GetMetadata) that yields a
// full-format identity token for the default service account, with the Agent
// Communication service as its audience.
constexpr absl::string_view kAcsTokenEndpointGce =
    "instance/service-accounts/default/identity"
    "?audience=agentcommunication.googleapis.com&format=full";

// Metadata path tried when the GCE identity endpoint above does not yield a
// usable token.
// TODO: b/384093718 - Update the endpoint once the token endpoint is finalized.
constexpr absl::string_view kAcsTokenEndpointGke =
    "instance/gke/agent-communication-service/"
    "ncclmetrics-token";
// The raw ACS identity token together with the claims that ParseAcsToken
// extracts from its payload.
struct AcsToken {
  // Token string exactly as returned by the metadata server.
  std::string token;
  // Value of the google.compute_engine.instance_id claim.
  std::string instance_id;
  // Value of the google.compute_engine.project_number claim.
  std::string project_number;
  // Value of the google.compute_engine.zone claim, e.g. "us-central1-a".
  std::string zone;
};
std::unique_ptr<Request> MakeAck(std::string message_id) {
google::rpc::Status status;
status.set_code(0);
return MakeRequestWithResponse(std::move(message_id), std::move(status));
}
// Builds a StreamAgentMessagesRequest carrying a message_response with the
// given status.
//
// Params:
//   message_id: value placed in the request's message_id field.
//   status: status placed in message_response.status.
// Returns: the newly allocated request.
std::unique_ptr<Request> MakeRequestWithResponse(std::string message_id,
                                                 google::rpc::Status status) {
  auto request = std::make_unique<Request>();
  request->set_message_id(std::move(message_id));
  // `status` is a by-value sink, so move it in instead of deep-copying it.
  *request->mutable_message_response()->mutable_status() = std::move(status);
  return request;
}
// Builds a StreamAgentMessagesRequest carrying a register_connection message.
//
// Params:
//   message_id: value placed in the request's message_id field.
//   channel_id: value placed in register_connection.channel_id.
//   resource_id: value placed in register_connection.resource_id.
// Returns: the newly allocated request.
std::unique_ptr<Request> MakeRequestWithRegistration(std::string message_id,
                                                     std::string channel_id,
                                                     std::string resource_id) {
  auto request = std::make_unique<Request>();
  request->set_message_id(std::move(message_id));
  // Populate the sub-message in place rather than building a temporary.
  auto* registration = request->mutable_register_connection();
  registration->set_channel_id(std::move(channel_id));
  registration->set_resource_id(std::move(resource_id));
  return request;
}
std::unique_ptr<Response> MakeAckResponse(std::string message_id) {
google::rpc::Status status;
status.set_code(0);
return MakeResponseWithResponse(std::move(message_id), std::move(status));
}
// Builds a StreamAgentMessagesResponse carrying a message_response with the
// given status.
//
// Params:
//   message_id: value placed in the response's message_id field.
//   status: status placed in message_response.status.
// Returns: the newly allocated response.
std::unique_ptr<Response> MakeResponseWithResponse(std::string message_id,
                                                   google::rpc::Status status) {
  auto response = std::make_unique<Response>();
  response->set_message_id(std::move(message_id));
  // `status` is a by-value sink, so move it in instead of deep-copying it.
  *response->mutable_message_response()->mutable_status() = std::move(status);
  return response;
}
// Performs an HTTP GET of `url` with one extra request header and returns the
// response body.
//
// Params:
//   url: the URL to fetch.
//   header: a single header line, e.g. "Metadata-Flavor: Google".
// Returns: the response body when the server answers with a 2xx status;
//   InternalError if libcurl cannot be set up, the transfer fails, or the
//   server answers with any other HTTP status.
absl::StatusOr<std::string> CurlHttpGet(const std::string& url,
                                        const std::string& header) {
  // RAII owners: every return path below releases the easy handle and the
  // header list without hand-written cleanup calls.
  std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(),
                                                           &curl_easy_cleanup);
  if (curl == nullptr) {
    ABSL_LOG(ERROR) << "Failed to initialize curl.";
    return absl::InternalError("Failed to initialize curl.");
  }
  // Set URL.
  CURLcode res = curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
  if (res != CURLE_OK) {
    ABSL_LOG(ERROR) << "Failed to set URL: " << url;
    return absl::InternalError(absl::StrCat(
        "Failed to set URL: ", url, " with error: ", curl_easy_strerror(res)));
  }
  // Set header. curl_slist_append returns nullptr when it cannot allocate.
  // There is no CURLcode for that case, so it is reported separately instead
  // of printing curl_easy_strerror(CURLE_OK) ("No error").
  std::unique_ptr<curl_slist, decltype(&curl_slist_free_all)> headers(
      curl_slist_append(nullptr, header.c_str()), &curl_slist_free_all);
  if (headers == nullptr) {
    ABSL_LOG(ERROR) << "Failed to set header: " << header;
    return absl::InternalError(
        absl::StrCat("Failed to set header: ", header,
                     " with error: failed to allocate the header list"));
  }
  res = curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, headers.get());
  if (res != CURLE_OK) {
    ABSL_LOG(ERROR) << "Failed to set header: " << header;
    return absl::InternalError(
        absl::StrCat("Failed to set header: ", header,
                     " with error: ", curl_easy_strerror(res)));
  }
  // Set the write callback function and its data.
  // No need to check the return value as they will both return CURLE_OK.
  std::string read_buffer;
  curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallback);
  curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &read_buffer);
  res = curl_easy_perform(curl.get());
  if (res != CURLE_OK) {
    ABSL_LOG(ERROR) << "curl_easy_perform() failed: "
                    << curl_easy_strerror(res);
    return absl::InternalError(absl::StrCat(
        "curl_easy_perform() failed with error: ", curl_easy_strerror(res)));
  }
  // libcurl reports CURLE_OK for any completed transfer, including HTTP error
  // responses (4xx/5xx). The status must therefore be checked explicitly,
  // otherwise an error page would be returned as if it were the value.
  long http_code = 0;  // CURLINFO_RESPONSE_CODE requires a long.
  res = curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
  if (res != CURLE_OK) {
    ABSL_LOG(ERROR) << "Failed to get HTTP response code for: " << url;
    return absl::InternalError(
        absl::StrCat("Failed to get HTTP response code for: ", url,
                     " with error: ", curl_easy_strerror(res)));
  }
  if (http_code < 200 || http_code >= 300) {
    ABSL_LOG(ERROR) << "HTTP GET " << url << " returned status " << http_code;
    return absl::InternalError(absl::StrCat(
        "HTTP GET ", url, " failed with HTTP status ", http_code));
  }
  // Log only the size: the body can be a credential such as an identity token.
  ABSL_VLOG(1) << "Got metadata for key: " << url << " ("
               << read_buffer.size() << " bytes)";
  return read_buffer;
}
// Reads `key` from the metadata server and returns the raw response body.
//
// Params:
//   key: a path relative to http://metadata.google.internal/computeMetadata/v1/
//     (for example "instance/zone").
// Returns: whatever CurlHttpGet returns for the assembled URL.
absl::StatusOr<std::string> GetMetadata(absl::string_view key) {
  const std::string url = absl::StrCat(
      "http://metadata.google.internal/computeMetadata/v1/", key);
  // Per the metadata server documentation, requests must carry this header.
  return CurlHttpGet(url, "Metadata-Flavor: Google");
}
// Fetches the ACS identity token from the metadata path `endpoint` and extracts
// the google.compute_engine instance_id, project_number and zone claims from
// its payload.
//
// Params:
//   endpoint: metadata path relative to computeMetadata/v1/ (see GetMetadata).
// Returns: the token plus its claims, or the first error encountered while
//   fetching the token or reading one of the claims.
absl::StatusOr<AcsToken> ParseAcsToken(absl::string_view endpoint) {
  absl::StatusOr<std::string> token = GetMetadata(endpoint);
  if (!token.ok()) {
    return token.status();
  }
  // The token is a bearer credential, so its value must not reach the logs.
  ABSL_VLOG(2) << "Successfully got token from metadata service ("
               << token->size() << " bytes).";
  absl::StatusOr<std::string> instance_id = GetValueFromTokenPayloadWithKeys(
      *token, {"google", "compute_engine", "instance_id"});
  if (!instance_id.ok()) {
    return instance_id.status();
  }
  ABSL_VLOG(2) << "Successfully got instance_id from metadata service: "
               << *instance_id;
  absl::StatusOr<std::string> project_number = GetValueFromTokenPayloadWithKeys(
      *token, {"google", "compute_engine", "project_number"});
  if (!project_number.ok()) {
    return project_number.status();
  }
  ABSL_VLOG(2) << "Successfully got project_number from metadata service: "
               << *project_number;
  absl::StatusOr<std::string> zone = GetValueFromTokenPayloadWithKeys(
      *token, {"google", "compute_engine", "zone"});
  if (!zone.ok()) {
    return zone.status();
  }
  ABSL_VLOG(2) << "Successfully got zone from metadata service: " << *zone;
  return AcsToken{.token = *std::move(token),
                  .instance_id = *std::move(instance_id),
                  .project_number = *std::move(project_number),
                  .zone = *std::move(zone)};
}
// Builds the connection parameters for the ACS Agent Communication service.
//
// The identity token is requested from the GCE metadata endpoint first; if
// that fails, the GKE endpoint is tried. The token's zone, project number and
// instance id claims are then used to derive the service endpoint and the
// resource id.
//
// Params:
//   channel_id: the channel to register; copied verbatim into the result.
//   regional: if true, the location is the region derived from the zone
//     (e.g. zone "us-central1-a" -> location "us-central1"); otherwise the
//     location is the zone itself.
// Returns: the populated AgentConnectionId. Returns an error if neither
//   endpoint yields a token (the message then reports both failures), or if
//   the zone claim is not of the form "<region>-<suffix>".
absl::StatusOr<AgentConnectionId> GenerateAgentConnectionId(
    std::string channel_id, bool regional) {
  // Named acs_token (not AcsToken) so the variable does not hide the type.
  absl::StatusOr<AcsToken> acs_token = ParseAcsToken(kAcsTokenEndpointGce);
  if (!acs_token.ok()) {
    // If the token is not available from the GCE endpoint, try the GKE
    // endpoint. Keep the GCE error so it is not lost if GKE fails too.
    const absl::Status gce_status = acs_token.status();
    ABSL_VLOG(1) << "Failed to get ACS token from the GCE endpoint, trying "
                    "the GKE endpoint: "
                 << gce_status;
    acs_token = ParseAcsToken(kAcsTokenEndpointGke);
    if (!acs_token.ok()) {
      return absl::Status(
          acs_token.status().code(),
          absl::StrCat("Failed to get ACS token from the GCE endpoint (",
                       gce_status.ToString(), ") and from the GKE endpoint (",
                       acs_token.status().ToString(), ")"));
    }
  }
  const std::string& zone = acs_token->zone;
  // Deduce the location from the zone: the region is everything before the
  // last hyphen. Example: zone us-central1-a -> region us-central1.
  // A hyphen at either end would yield an empty region or suffix, so such
  // zones are rejected instead of producing a bogus endpoint.
  const size_t last_hyphen_index = zone.rfind('-');
  if (last_hyphen_index == std::string::npos || last_hyphen_index == 0 ||
      last_hyphen_index + 1 == zone.size()) {
    return absl::InternalError(
        absl::StrCat("Wrong format of zone from metadata service: ", zone));
  }
  std::string location = regional ? zone.substr(0, last_hyphen_index) : zone;
  std::string endpoint =
      absl::StrContainsIgnoreCase(location, "staging")
          ? absl::StrCat(location,
                         "-agentcommunication.sandbox.googleapis.com:443")
          : absl::StrCat(location, "-agentcommunication.googleapis.com:443");
  std::string resource_id =
      absl::StrFormat("projects/%s/zones/%s/instances/%s",
                      acs_token->project_number, zone, acs_token->instance_id);
  return AgentConnectionId{.token = std::move(acs_token->token),
                           .resource_id = std::move(resource_id),
                           .channel_id = std::move(channel_id),
                           .endpoint = std::move(endpoint),
                           .regional = regional};
}
} // namespace agent_communication