e2e-examples/gcs/sample/main.cc (124 lines of code) (raw):

// Copyright 2022 gRPC authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include <grpcpp/channel.h> #include <grpcpp/client_context.h> #include <grpcpp/create_channel.h> #include <fstream> #include <memory> #include <streambuf> #include <string> #include <thread> #include <unordered_map> #include "absl/flags/flag.h" #include "absl/flags/parse.h" #include "absl/strings/str_format.h" #include "channel_manager.h" using ReadObjectRequest = google::storage::v2::ReadObjectRequest; using ReadObjectResponse = google::storage::v2::ReadObjectResponse; ABSL_FLAG(bool, directpath, true, "Whether to allow DirectPath"); ABSL_FLAG(std::string, access_token, "", "Access token for auth"); ABSL_FLAG(std::string, host, "", "Host to reach"); ABSL_FLAG(std::string, bucket, "gcs-grpc-team-veblush1", "Bucket to fetch object from"); ABSL_FLAG(std::string, object, "1MB.bin", "Object to download"); ABSL_FLAG(int, runs, 10, "Number of times to run the download"); ABSL_FLAG(int, threads, 8, "The number of threads running downloding objects"); ABSL_FLAG(int, channels, 4, "The max number of gRPC channels"); ABSL_FLAG(int, retries, 10, "The max number of gRPC retries"); std::shared_ptr<grpc::Channel> CreateBenchmarkGrpcChannel() { if (absl::GetFlag(FLAGS_access_token).empty()) { grpc::ChannelArguments channel_args; channel_args.SetServiceConfigJSON( "{\"loadBalancingConfig\":[{\"grpclb\":{" "\"childPolicy\":[{\"pick_first\":{}}]}}]" "}"); if (!absl::GetFlag(FLAGS_directpath)) { channel_args.SetInt("grpc.dns_enable_srv_queries", 0); // Disable DirectPath } std::shared_ptr<grpc::Channel> channel = grpc::CreateCustomChannel( std::string(absl::GetFlag(FLAGS_host)), grpc::GoogleDefaultCredentials(), channel_args); return channel; } else { std::shared_ptr<grpc::ChannelCredentials> credentials; std::shared_ptr<grpc::ChannelCredentials> channel_credentials = grpc::SslCredentials(grpc::SslCredentialsOptions()); if (absl::GetFlag(FLAGS_access_token) == "-") { credentials = channel_credentials; } else { std::shared_ptr<grpc::CallCredentials> call_credentials = grpc::AccessTokenCredentials( std::string(absl::GetFlag(FLAGS_access_token))); credentials = grpc::CompositeChannelCredentials(channel_credentials, call_credentials); } std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel( std::string(absl::GetFlag(FLAGS_host)), credentials); return channel; } grpc::ChannelArguments channel_args; channel_args.SetServiceConfigJSON( "{\"loadBalancingConfig\":[{\"grpclb\":{" "\"childPolicy\":[{\"pick_first\":{}}]}}]" "}"); std::shared_ptr<grpc::Channel> channel = grpc::CreateCustomChannel(std::string(absl::GetFlag(FLAGS_host)), grpc::GoogleDefaultCredentials(), channel_args); return channel; } void worker(ChannelManager& channel_manager, std::atomic_size_t& read_bytes) { // Downloads a given file N times. for (int i = 0; i < absl::GetFlag(FLAGS_runs); i++) { for (int j = 0; j < absl::GetFlag(FLAGS_retries); j++) { auto channel_handle = channel_manager.GetHandle(); ReadObjectRequest request; request.set_bucket("projects/_/buckets/" + absl::GetFlag(FLAGS_bucket)); request.set_object(absl::GetFlag(FLAGS_object)); grpc::ClientContext context; std::unique_ptr<grpc::ClientReader<ReadObjectResponse>> reader = channel_handle.GetStub<google::storage::v2::Storage::Stub>() ->ReadObject(&context, request); int64_t total_bytes = 0; ReadObjectResponse response; while (reader->Read(&response)) { int64_t content_size = response.checksummed_data().content().size(); total_bytes += content_size; } read_bytes += total_bytes; auto status = reader->Finish(); if (!status.ok()) { std::cerr << absl::StrFormat( "Download Error: Code=%d Message=%s Retries=%d from %s", status.error_code(), status.error_message(), j, context.peer()) << std::endl; } channel_handle.OnRpcDone(status); // In case of retriable error, it's going to retry it if (status.error_code() != grpc::StatusCode::CANCELLED && status.error_code() != grpc::StatusCode::DEADLINE_EXCEEDED) { break; } } } } int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); ChannelManager channel_manager(absl::GetFlag(FLAGS_channels), &CreateBenchmarkGrpcChannel); // Spawns benchmark runners and waits until they're done. absl::Time run_start = absl::Now(); std::atomic_size_t total_size(0); std::vector<std::thread> runner_threads; for (int i = 0; i < absl::GetFlag(FLAGS_threads); i++) { runner_threads.emplace_back([&channel_manager, &total_size]() { worker(channel_manager, total_size); }); } std::for_each(runner_threads.begin(), runner_threads.end(), [](std::thread& t) { t.join(); }); absl::Time run_end = absl::Now(); // Shows the result. double elapsed = absl::ToDoubleSeconds(run_end - run_start); std::cout << "Data: " << total_size << " bytes" << std::endl; std::cout << "Elapsed: " << elapsed << " sec" << std::endl; std::cout << "Throughput: " << total_size / elapsed / 1024 / 1024 << " MB/s" << std::endl; return 0; }