in cpp/src/arrow/flight/flight_benchmark.cc [398:550]
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
std::cout << "Testing method: ";
if (FLAGS_test_put) {
std::cout << "DoPut";
} else {
std::cout << "DoGet";
}
std::cout << std::endl;
arrow::flight::FlightCallOptions call_options;
if (!FLAGS_compression.empty()) {
if (!FLAGS_test_put) {
std::cerr << "Compression is only useful for Put test now, "
"please append \"-test_put\" to command line"
<< std::endl;
std::abort();
}
// "zstd" -> name = "zstd", level = default
// "zstd:7" -> name = "zstd", level = 7
const size_t delim = FLAGS_compression.find(":");
const std::string name = FLAGS_compression.substr(0, delim);
const std::string level_str =
delim == std::string::npos
? ""
: FLAGS_compression.substr(delim + 1, FLAGS_compression.length() - delim - 1);
const int level = level_str.empty() ? arrow::util::kUseDefaultCompressionLevel
: std::stoi(level_str);
const auto type = arrow::util::Codec::GetCompressionType(name).ValueOrDie();
auto codec = arrow::util::Codec::Create(type, level).ValueOrDie();
std::cout << "Compression method: " << name;
if (!level_str.empty()) {
std::cout << ", level " << level;
}
std::cout << std::endl;
call_options.write_options.codec = std::move(codec);
}
if (!FLAGS_data_file.empty() && !FLAGS_test_put) {
std::cerr << "A data file can only be specified with \"-test_put\"" << std::endl;
return 1;
}
std::unique_ptr<arrow::flight::TestServer> server;
std::vector<std::string> server_args;
server_args.push_back("-transport");
server_args.push_back(FLAGS_transport);
arrow::flight::Location location;
auto options = arrow::flight::FlightClientOptions::Defaults();
if (FLAGS_transport == "grpc") {
if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
if (FLAGS_server_unix == "") {
FLAGS_server_unix = "/tmp/flight-bench-spawn.sock";
std::cout << "Using spawned Unix server" << std::endl;
server.reset(
new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix));
} else {
std::cout << "Using standalone Unix server" << std::endl;
}
std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
ABORT_NOT_OK(
arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix).Value(&location));
} else {
if (FLAGS_server_host == "") {
FLAGS_server_host = "localhost";
std::cout << "Using spawned TCP server" << std::endl;
server.reset(
new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port));
if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
std::cout << "Enabling TLS for spawned server" << std::endl;
server_args.push_back("-cert_file");
server_args.push_back(FLAGS_cert_file);
server_args.push_back("-key_file");
server_args.push_back(FLAGS_key_file);
} else {
std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
return 1;
}
}
} else {
std::cout << "Using standalone TCP server" << std::endl;
}
if (server) {
if (FLAGS_cuda && FLAGS_test_put) {
server_args.push_back("-cuda");
}
ABORT_NOT_OK(server->Start(server_args));
}
std::cout << "Server host: " << FLAGS_server_host << std::endl
<< "Server port: " << FLAGS_server_port << std::endl;
if (FLAGS_cert_file.empty()) {
ABORT_NOT_OK(
arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_server_port)
.Value(&location));
} else {
ABORT_NOT_OK(
arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_server_port)
.Value(&location));
options.disable_server_verification = true;
}
}
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
}
if (FLAGS_cuda) {
#ifdef ARROW_CUDA
if (FLAGS_test_put && !server) {
std::cerr << "Warning: -cuda has no effect with -test_put" << std::endl;
std::cerr << "Warning: (enable it on the server instead)" << std::endl;
}
arrow::cuda::CudaDeviceManager* manager = nullptr;
std::shared_ptr<arrow::cuda::CudaDevice> device;
ABORT_NOT_OK(arrow::cuda::CudaDeviceManager::Instance().Value(&manager));
ABORT_NOT_OK(manager->GetDevice(0).Value(&device));
call_options.memory_manager = device->default_memory_manager();
// Needed to prevent UCX warning
// cuda_md.c:162 UCX ERROR cuMemGetAddressRange(0x7f2ab5dc0000) error: invalid
// device context
std::shared_ptr<arrow::cuda::CudaContext> context;
ABORT_NOT_OK(device->GetContext().Value(&context));
auto cuda_status = cuCtxPushCurrent(reinterpret_cast<CUcontext>(context->handle()));
if (cuda_status != CUDA_SUCCESS) {
ARROW_LOG(WARNING) << "CUDA error " << cuda_status;
}
#else
std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl;
return 1;
#endif
}
auto client = arrow::flight::FlightClient::Connect(location, options).ValueOrDie();
ABORT_NOT_OK(arrow::flight::WaitForReady(client.get(), call_options));
arrow::Status s = arrow::flight::RunPerformanceTest(client.get(), options, call_options,
FLAGS_test_put);
if (server) {
server->Stop();
}
if (!s.ok()) {
std::cerr << "Failed with error: << " << s.ToString() << std::endl;
}
return 0;
}