// dissociated-ipc/cudf-flight-client.cc
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <arrow/array.h>
#include <arrow/flight/client.h>
#include <arrow/gpu/cuda_api.h>
#include <arrow/ipc/api.h>
#include <arrow/util/endian.h>
#include <arrow/util/logging.h>
#include <arrow/util/uri.h>
#include "cudf-flight-ucx.h"
#include "ucx_client.h"
namespace flight = arrow::flight;
namespace ipc = arrow::ipc;
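// Extract the UCX tag from the `want_data` query parameter of a Flight
// location URI, falling back to tag 0 when the parameter is absent. For
// illustration only (the values are hypothetical), a location such as
// ucx://hostname:port?want_data=42 would yield tag 42.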
arrow::Result<ucp_tag_t> get_want_data_tag(const arrow::util::Uri& loc) {
ARROW_ASSIGN_OR_RAISE(auto query_params, loc.query_items());
for (auto& q : query_params) {
if (q.first == "want_data") {
return std::stoull(q.second);
}
}
return 0;
}
// Utility client class that reads a stream of record batches using the
// dissociated IPC protocol: IPC metadata arrives as active messages on the
// control connection, while message bodies arrive as tagged data on a
// separate data connection.
class StreamReader {
public:
StreamReader(utils::Connection* ctrl_cnxn, utils::Connection* data_cnxn)
: ctrl_cnxn_{ctrl_cnxn}, data_cnxn_{data_cnxn} {
ARROW_UNUSED(ctrl_cnxn_->SetAMHandler(0, this, RecvMsg));
}
void set_data_mem_manager(std::shared_ptr<arrow::MemoryManager> mgr) {
if (!mgr) {
mm_ = arrow::CPUDevice::Instance()->default_memory_manager();
} else {
mm_ = std::move(mgr);
}
}
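  // Start consuming the stream: send our stream identity (the Flight ticket)
  // on both connections using the server-provided want_data tags, then spawn
  // detached threads that drive the metadata and body-data loops.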
arrow::Status Start(ucp_tag_t ctrl_tag, ucp_tag_t data_tag, const std::string& ident) {
    // announce our stream identity on both connections using the want_data tags
    ARROW_RETURN_NOT_OK(ctrl_cnxn_->SendTagSync(ctrl_tag, ident.data(), ident.size()));
    ARROW_RETURN_NOT_OK(data_cnxn_->SendTagSync(data_tag, ident.data(), ident.size()));
    // consume the data and metadata streams simultaneously
    std::thread(&StreamReader::run_data_loop, this).detach();
    std::thread(&StreamReader::run_meta_loop, this).detach();
return arrow::Status::OK();
}
arrow::Result<std::shared_ptr<arrow::Schema>> Schema() {
// return the schema if we've already pulled it
if (schema_) {
return schema_;
}
// otherwise the next message should be the schema
ARROW_ASSIGN_OR_RAISE(auto msg, NextMsg());
ARROW_ASSIGN_OR_RAISE(schema_, ipc::ReadSchema(*msg, &dictionary_memo_));
return schema_;
}
arrow::Result<std::shared_ptr<arrow::RecordBatch>> Next() {
    // reading a record batch requires the schema; calling Schema() here also
    // ensures the schema message (which must arrive first) has been consumed
ARROW_ASSIGN_OR_RAISE(auto schema, Schema());
ARROW_ASSIGN_OR_RAISE(auto msg, NextMsg());
if (msg) {
return ipc::ReadRecordBatch(*msg, schema, &dictionary_memo_, ipc_options_);
}
// we've hit the end
return nullptr;
}
protected:
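  // Bookkeeping for a body message we are still waiting on: the promise that
  // NextMsg() blocks on, the IPC metadata already received on the control
  // stream, the buffer the tagged body data will land in, and a back-pointer
  // to the reader so the receive callback can update its counters.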
struct PendingMsg {
std::promise<std::unique_ptr<ipc::Message>> p;
std::shared_ptr<arrow::Buffer> metadata;
std::shared_ptr<arrow::Buffer> body;
StreamReader* rdr;
};
// data stream loop handler
void run_data_loop() {
if (arrow::cuda::IsCudaMemoryManager(*mm_)) {
      // since we're running in a new thread, push the CUDA context so that UCX
      // uses the same context as the Arrow data; otherwise the device pointers
      // we hand it would not be valid
auto ctx = *(*arrow::cuda::AsCudaMemoryManager(mm_))->cuda_device()->GetContext();
cuCtxPushCurrent(reinterpret_cast<CUcontext>(ctx->handle()));
}
while (true) {
// progress the connection until an event happens
while (data_cnxn_->Progress()) {
}
{
        // check whether we have received any metadata that indicates we need to
        // poll for a corresponding tagged data message
std::unique_lock<std::mutex> guard(polling_mutex_);
for (auto it = polling_map_.begin(); it != polling_map_.end();) {
auto maybe_tag =
data_cnxn_->ProbeForTag(ucp_tag_t(it->first), 0x00000000FFFFFFFF, 1);
if (!maybe_tag.ok()) {
ARROW_LOG(ERROR) << maybe_tag.status().ToString();
return;
}
auto tag_pair = maybe_tag.MoveValueUnsafe();
if (tag_pair.second != nullptr) {
// got one!
auto st = RecvTag(tag_pair.second, tag_pair.first, std::move(it->second));
if (!st.ok()) {
ARROW_LOG(ERROR) << st.ToString();
return;
}
it = polling_map_.erase(it);
} else {
++it;
}
}
}
// if the metadata stream has ended...
if (finished_metadata_.load()) {
// we are done if there's nothing left to poll for and nothing outstanding
std::lock_guard<std::mutex> guard(polling_mutex_);
if (polling_map_.empty() && outstanding_tags_.load() == 0) {
break;
}
}
}
}
  // mask for the high byte of the sender tag, which marks whether the tagged
  // body message carries the full body bytes or a list of pointers/offsets
  static constexpr uint64_t kbody_mask_ = 0x0100000000000000;
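  // Post a non-blocking tagged receive for the body that goes with a metadata
  // message we already have: allocate a buffer from the configured memory
  // manager (host or CUDA), hand it to UCX, and complete the pending promise
  // with the reconstructed IPC message once the data lands.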
arrow::Status RecvTag(ucp_tag_message_h msg, ucp_tag_recv_info_t info_tag,
PendingMsg pending) {
++outstanding_tags_;
ARROW_ASSIGN_OR_RAISE(auto buf, mm_->AllocateBuffer(info_tag.length));
PendingMsg* new_pending = new PendingMsg(std::move(pending));
new_pending->body = std::move(buf);
new_pending->rdr = this;
return data_cnxn_->RecvTagData(
msg, reinterpret_cast<void*>(new_pending->body->address()), info_tag.length,
new_pending,
[](void* request, ucs_status_t status, const ucp_tag_recv_info_t* tag_info,
void* user_data) {
auto pending =
std::unique_ptr<PendingMsg>(reinterpret_cast<PendingMsg*>(user_data));
          if (status != UCS_OK) {
            ARROW_LOG(ERROR)
                << utils::FromUcsStatus("ucp_tag_recv_nbx_callback", status).ToString();
            // drop our claim on the tag so the data loop can still terminate
            --pending->rdr->outstanding_tags_;
            pending->p.set_value(nullptr);
            return;
          }
if (request) ucp_request_free(request);
if (tag_info->sender_tag & kbody_mask_) {
// pointer / offset list body
// not yet implemented
} else {
// full body bytes, use the pending metadata and read our IPC message
// as usual
auto msg = *ipc::Message::Open(pending->metadata, pending->body);
pending->p.set_value(std::move(msg));
--pending->rdr->outstanding_tags_;
}
},
(new_pending->body->is_cpu()) ? UCS_MEMORY_TYPE_HOST : UCS_MEMORY_TYPE_CUDA);
}
// handle the metadata stream
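  // Each control message is a small buffer laid out as: byte 0 is the
  // MetadataMsgType (EOS ends the stream), bytes 1-4 are the little-endian
  // sequence number, and the remaining bytes are the Arrow IPC metadata.
  // Messages without a body complete their promise immediately; messages with
  // a body are parked in polling_map_ until the matching tagged data shows up
  // on the data connection.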
void run_meta_loop() {
while (!finished_metadata_.load()) {
// progress the connection until we get an event
while (ctrl_cnxn_->Progress()) {
}
{
std::unique_lock<std::mutex> guard(queue_mutex_);
while (!metadata_queue_.empty()) {
// handle any metadata messages in our queue
auto buf = std::move(metadata_queue_.front());
metadata_queue_.pop();
guard.unlock();
while (buf.wait_for(std::chrono::seconds(0)) != std::future_status::ready) {
ctrl_cnxn_->Progress();
}
std::shared_ptr<arrow::Buffer> buffer = buf.get();
          if (static_cast<MetadataMsgType>(buffer->data()[0]) == MetadataMsgType::EOS) {
            finished_metadata_.store(true);
            // wake any NextMsg caller waiting on a sequence number that will never arrive
            cv_progress_.notify_all();
            guard.lock();
            continue;
          }
uint32_t sequence_number = utils::BytesToUint32LE(buffer->data() + 1);
auto metadata = SliceBuffer(buffer, 5, buffer->size() - 5);
          // store a mapping from the sequence number to a std::future that will yield the message
std::promise<std::unique_ptr<ipc::Message>> p;
{
std::lock_guard<std::mutex> lock(msg_mutex_);
msg_map_.insert({sequence_number, p.get_future()});
}
cv_progress_.notify_all();
auto msg = ipc::Message::Open(metadata, nullptr).ValueOrDie();
if (!ipc::Message::HasBody(msg->type())) {
p.set_value(std::move(msg));
guard.lock();
continue;
}
{
std::lock_guard<std::mutex> lock(polling_mutex_);
polling_map_.insert(
{sequence_number, PendingMsg{std::move(p), std::move(metadata)}});
}
guard.lock();
}
}
if (finished_metadata_.load()) break;
auto status = utils::FromUcsStatus("ucp_worker_wait", ctrl_cnxn_->WorkerWait());
if (!status.ok()) {
ARROW_LOG(ERROR) << status.ToString();
return;
}
}
}
arrow::Result<std::unique_ptr<ipc::Message>> NextMsg() {
// fetch the next IPC message by sequence number
const uint32_t counter = next_counter_++;
std::future<std::unique_ptr<ipc::Message>> futr;
{
std::unique_lock<std::mutex> lock(msg_mutex_);
if (msg_map_.empty() && finished_metadata_.load() && !outstanding_tags_.load()) {
return nullptr;
}
auto it = msg_map_.find(counter);
if (it == msg_map_.end()) {
// wait until we get a message for this sequence number
cv_progress_.wait(lock, [this, counter, &it] {
it = msg_map_.find(counter);
return it != msg_map_.end() || finished_metadata_.load();
});
      }
      if (it == msg_map_.end()) {
        // the metadata stream ended before a message with this sequence number arrived
        return nullptr;
      }
      futr = std::move(it->second);
      msg_map_.erase(it);
}
    // .get() on the future blocks until it either receives a value or fails
return futr.get();
}
  // callback function to receive untagged "Active Messages" carrying IPC metadata
static ucs_status_t RecvMsg(void* arg, const void* header, size_t header_len,
void* data, size_t length,
const ucp_am_recv_param_t* param) {
StreamReader* rdr = reinterpret_cast<StreamReader*>(arg);
DCHECK(length);
std::promise<std::unique_ptr<arrow::Buffer>> p;
{
std::lock_guard<std::mutex> lock(rdr->queue_mutex_);
rdr->metadata_queue_.push(p.get_future());
}
return rdr->ctrl_cnxn_->RecvAM(std::move(p), header, header_len, data, length, param);
}
private:
utils::Connection* ctrl_cnxn_;
utils::Connection* data_cnxn_;
std::shared_ptr<arrow::Schema> schema_;
ipc::DictionaryMemo dictionary_memo_;
ipc::IpcReadOptions ipc_options_;
std::shared_ptr<arrow::MemoryManager> mm_;
std::atomic<bool> finished_metadata_{false};
std::atomic<uint32_t> outstanding_tags_{0};
uint32_t next_counter_{0};
std::condition_variable cv_progress_;
std::mutex queue_mutex_;
std::queue<std::future<std::unique_ptr<arrow::Buffer>>> metadata_queue_;
std::mutex polling_mutex_;
std::unordered_map<uint32_t, PendingMsg> polling_map_;
std::mutex msg_mutex_;
std::unordered_map<uint32_t, std::future<std::unique_ptr<ipc::Message>>> msg_map_;
};
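
// Connect to the Flight service over gRPC, request the "train.parquet"
// command, and use the returned endpoint to open the dissociated streams:
// the first location is the control (metadata) connection and the second is
// the data connection, each advertising a want_data tag in its URI. The
// ticket serves as the stream identity. Batches arrive in CUDA memory and
// are copied to the CPU only for printing.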
arrow::Status run_client(const std::string& addr, const int port) {
ARROW_ASSIGN_OR_RAISE(auto location, flight::Location::ForGrpcTcp(addr, port));
ARROW_ASSIGN_OR_RAISE(auto client, flight::FlightClient::Connect(location));
ARROW_ASSIGN_OR_RAISE(
auto info,
client->GetFlightInfo(flight::FlightDescriptor::Command("train.parquet")));
ARROW_LOG(DEBUG) << info->endpoints()[0].locations[0].ToString();
ARROW_LOG(DEBUG) << info->endpoints()[0].locations[1].ToString();
ARROW_ASSIGN_OR_RAISE(auto ctrl_uri, arrow::util::Uri::FromString(
info->endpoints()[0].locations[0].ToString()));
ARROW_ASSIGN_OR_RAISE(auto data_uri, arrow::util::Uri::FromString(
info->endpoints()[0].locations[1].ToString()));
ARROW_ASSIGN_OR_RAISE(ucp_tag_t ctrl_tag, get_want_data_tag(ctrl_uri));
ARROW_ASSIGN_OR_RAISE(ucp_tag_t data_tag, get_want_data_tag(data_uri));
const std::string& ident = info->endpoints()[0].ticket.ticket;
ARROW_ASSIGN_OR_RAISE(auto cuda_mgr, arrow::cuda::CudaDeviceManager::Instance());
ARROW_ASSIGN_OR_RAISE(auto device, cuda_mgr->GetDevice(0));
ARROW_ASSIGN_OR_RAISE(auto cuda_device, arrow::cuda::AsCudaDevice(device));
ARROW_ASSIGN_OR_RAISE(auto ctx, cuda_device->GetContext());
cuCtxPushCurrent(reinterpret_cast<CUcontext>(ctx->handle()));
ARROW_LOG(DEBUG) << device->ToString();
UcxClient ctrl_client, data_client;
ARROW_RETURN_NOT_OK(ctrl_client.Init(ctrl_uri.host(), ctrl_uri.port()));
ARROW_RETURN_NOT_OK(data_client.Init(data_uri.host(), data_uri.port()));
ARROW_ASSIGN_OR_RAISE(auto ctrl_cnxn, ctrl_client.CreateConn());
ARROW_ASSIGN_OR_RAISE(auto data_cnxn, data_client.CreateConn());
StreamReader rdr(ctrl_cnxn.get(), data_cnxn.get());
rdr.set_data_mem_manager(ctx->memory_manager());
ARROW_RETURN_NOT_OK(rdr.Start(ctrl_tag, data_tag, ident));
ARROW_ASSIGN_OR_RAISE(auto s, rdr.Schema());
std::cout << s->ToString() << std::endl;
while (true) {
ARROW_ASSIGN_OR_RAISE(auto batch, rdr.Next());
if (!batch) {
break;
}
std::cout << batch->num_columns() << " " << batch->num_rows() << std::endl;
std::cout << batch->column(0)->data()->buffers[1]->device()->ToString() << std::endl;
ARROW_ASSIGN_OR_RAISE(auto cpubatch,
batch->CopyTo(arrow::default_cpu_memory_manager()));
std::cout << cpubatch->ToString() << std::endl;
}
ARROW_CHECK_OK(ctrl_cnxn->Close());
ARROW_CHECK_OK(data_cnxn->Close());
return arrow::Status::OK();
}
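
// A minimal caller might look like the sketch below; the entry point's name,
// defaults, and argument handling are assumptions for illustration, since the
// actual main() for this example lives elsewhere.
//
//   int main(int argc, char** argv) {
//     // hypothetical defaults for the Flight server's host and port
//     std::string host = argc > 1 ? argv[1] : "127.0.0.1";
//     int port = argc > 2 ? std::stoi(argv[2]) : 31337;
//     auto st = run_client(host, port);
//     if (!st.ok()) {
//       std::cerr << st.ToString() << std::endl;
//       return 1;
//     }
//     return 0;
//   }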