dissociated-ipc/ucx_conn.h (61 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. #pragma once #include <ucp/api/ucp.h> #include <future> #include <memory> #include <utility> #include "arrow/util/logging.h" #include "ucx_utils.h" namespace utils { class Connection { public: explicit Connection(std::shared_ptr<UcpWorker> worker); Connection(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint); ARROW_DISALLOW_COPY_AND_ASSIGN(Connection); ARROW_DEFAULT_MOVE_AND_ASSIGN(Connection); ~Connection() { DCHECK(!ucp_worker_) << "Connection was not closed!"; } arrow::Status CreateEndpoint(ucp_conn_request_h request); arrow::Status CreateEndpoint(const sockaddr_storage& addr, const size_t addrlen); arrow::Status Flush(); arrow::Status Close(); inline bool is_closed() const { return closed_; } inline unsigned int Progress() { return ucp_worker_progress(ucp_worker_->get()); } inline ucs_status_t WorkerWait() { return ucp_worker_wait(ucp_worker_->get()); } arrow::Status SetAMHandler(unsigned int id, void* user_data, ucp_am_recv_callback_t cb); arrow::Result<std::pair<ucp_tag_recv_info_t, ucp_tag_message_h>> ProbeForTag( ucp_tag_t tag, ucp_tag_t mask, int remove); arrow::Result<std::pair<ucp_tag_recv_info_t, ucp_tag_message_h>> ProbeForTagSync( ucp_tag_t tag, ucp_tag_t mask, int remove); arrow::Status RecvTagData(ucp_tag_message_h msg, void* buffer, const size_t count, void* user_data, ucp_tag_recv_nbx_callback_t cb, const ucs_memory_type_t memory_type); ucs_status_t RecvAM(std::promise<std::unique_ptr<arrow::Buffer>> p, const void* header, const size_t header_length, void* data, const size_t data_length, const ucp_am_recv_param_t* param); arrow::Status SendAM(unsigned int id, const void* data, const int64_t size); arrow::Status SendAMIov(unsigned int id, const ucp_dt_iov_t* iov, const size_t iov_cnt, void* user_data, ucp_send_nbx_callback_t cb, const ucs_memory_type_t memory_type); arrow::Status SendTagIov(ucp_tag_t tag, const ucp_dt_iov_t* iov, const size_t iov_cnt, void* user_data, ucp_send_nbx_callback_t cb, const ucs_memory_type_t memory_type); arrow::Status SendTagSync(ucp_tag_t tag, const void* buffer, const size_t count); protected: static void err_cb(void* arg, ucp_ep_h ep, ucs_status_t status) { if (!is_ignorable_disconnect_error(status)) { ARROW_LOG(DEBUG) << FromUcsStatus("error handling callback", status).ToString(); } Connection* cnxn = reinterpret_cast<Connection*>(arg); cnxn->closed_ = true; } inline arrow::Status CheckClosed() { if (!remote_endpoint_) { return arrow::Status::Invalid("connection is closed"); } return arrow::Status::OK(); } private: std::shared_ptr<utils::UcpWorker> ucp_worker_; ucp_ep_h remote_endpoint_; bool closed_{false}; }; } // namespace utils