dissociated-ipc/ucx_utils.h (84 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <ucp/api/ucp.h>
#include <arrow/buffer.h>
#include <arrow/status.h>
#include <arrow/util/endian.h>
#include <arrow/util/logging.h>
#include <arrow/util/ubsan.h>
namespace utils {
static inline void Uint32ToBytesLE(const uint32_t in, uint8_t* out) {
arrow::util::SafeStore(out, arrow::bit_util::ToLittleEndian(in));
}
static inline uint32_t BytesToUint32LE(const uint8_t* in) {
return arrow::bit_util::FromLittleEndian(arrow::util::SafeLoadAs<uint32_t>(in));
}
class UcpContext final {
public:
UcpContext() = default;
explicit UcpContext(ucp_context_h context) : ucp_context_(context) {}
~UcpContext() {
if (ucp_context_) ucp_cleanup(ucp_context_);
ucp_context_ = nullptr;
}
ucp_context_h get() const {
DCHECK(ucp_context_);
return ucp_context_;
}
private:
ucp_context_h ucp_context_{nullptr};
};
class UcpWorker final {
public:
UcpWorker() = default;
UcpWorker(std::shared_ptr<UcpContext> context, ucp_worker_h worker)
: ucp_context_(std::move(context)), ucp_worker_(worker) {}
~UcpWorker() {
if (ucp_worker_) ucp_worker_destroy(ucp_worker_);
ucp_worker_ = nullptr;
}
ucp_worker_h get() const { return ucp_worker_; }
const UcpContext& context() const { return *ucp_context_; }
private:
ucp_worker_h ucp_worker_{nullptr};
std::shared_ptr<UcpContext> ucp_context_;
};
class UcxStatusDetail : public arrow::StatusDetail {
public:
explicit UcxStatusDetail(ucs_status_t status) : status_(status) {}
static constexpr char const kTypeId[] = "ucx::UcxStatusDetail";
const char* type_id() const override { return kTypeId; }
std::string ToString() const override;
static ucs_status_t Unwrap(const arrow::Status& status);
private:
ucs_status_t status_;
};
arrow::Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status);
class UcxDataBuffer : public arrow::Buffer {
public:
UcxDataBuffer(std::shared_ptr<UcpWorker> worker, void* data, const size_t size)
: arrow::Buffer(reinterpret_cast<uint8_t*>(data), static_cast<int64_t>(size)),
worker_(std::move(worker)) {}
~UcxDataBuffer() override {
ucp_am_data_release(worker_->get(),
const_cast<void*>(reinterpret_cast<const void*>(data())));
}
private:
std::shared_ptr<UcpWorker> worker_;
};
arrow::Result<size_t> to_sockaddr(const std::string& host, const int32_t port,
struct sockaddr_storage* addr);
arrow::Result<std::string> SockaddrToString(const struct sockaddr_storage& address);
static inline bool is_ignorable_disconnect_error(ucs_status_t ucs_status) {
// not connected, connection reset: we're already disconnected
// timeout: most likely disconnected, but we can't tell from our end
switch (ucs_status) {
case UCS_OK:
case UCS_ERR_ENDPOINT_TIMEOUT:
case UCS_ERR_NOT_CONNECTED:
case UCS_ERR_CONNECTION_RESET:
return true;
}
return false;
}
} // namespace utils