dissociated-ipc/ucx_server.cc (211 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. #include <arpa/inet.h> #include <netdb.h> #include "ucx_server.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/io_util.h" #include "arrow/util/string.h" namespace { arrow::Result<std::shared_ptr<utils::UcpContext>> init_ucx( struct sockaddr_storage connect_addr) { ucp_config_t* ucp_config; ucp_params_t ucp_params; ucs_status_t status = ucp_config_read(nullptr, nullptr, &ucp_config); RETURN_NOT_OK(utils::FromUcsStatus("ucp_config_read", status)); // if location is ipv6, adjust config if (connect_addr.ss_family == AF_INET6) { status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6"); RETURN_NOT_OK(utils::FromUcsStatus("ucp_config_modify", status)); } std::memset(&ucp_params, 0, sizeof(ucp_params)); ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_NAME | UCP_PARAM_FIELD_MT_WORKERS_SHARED; ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_TAG | UCP_FEATURE_RMA | UCP_FEATURE_WAKEUP | UCP_FEATURE_STREAM; ucp_params.mt_workers_shared = UCS_THREAD_MODE_MULTI; ucp_params.name = "cuda-flight-ucx"; ucp_context_h ucp_context; status = ucp_init(&ucp_params, ucp_config, &ucp_context); ucp_config_release(ucp_config); RETURN_NOT_OK(utils::FromUcsStatus("ucp_init", status)); return std::make_shared<utils::UcpContext>(ucp_context); } arrow::Result<std::shared_ptr<utils::UcpWorker>> create_listener_worker( std::shared_ptr<utils::UcpContext> ctx) { ucp_worker_params_t worker_params; ucs_status_t status; std::memset(&worker_params, 0, sizeof(worker_params)); worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; ucp_worker_h worker; status = ucp_worker_create(ctx->get(), &worker_params, &worker); RETURN_NOT_OK(utils::FromUcsStatus("ucp_worker_create", status)); return std::make_shared<utils::UcpWorker>(std::move(ctx), worker); } } // namespace arrow::Status UcxServer::Init(const std::string& host, const int32_t port) { struct sockaddr_storage listen_addr; ARROW_ASSIGN_OR_RAISE(auto addrlen, utils::to_sockaddr(host, port, &listen_addr)); ARROW_ASSIGN_OR_RAISE(ucp_context_, init_ucx(listen_addr)); ARROW_ASSIGN_OR_RAISE(worker_conn_, create_listener_worker(ucp_context_)); { ucp_listener_params_t params; ucs_status_t status; params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; params.sockaddr.addr = reinterpret_cast<const sockaddr*>(&listen_addr); params.sockaddr.addrlen = addrlen; params.conn_handler.cb = HandleIncomingConnection; params.conn_handler.arg = this; status = ucp_listener_create(worker_conn_->get(), &params, &listener_); RETURN_NOT_OK(utils::FromUcsStatus("ucp_listener_create", status)); // get real address/port ucp_listener_attr_t attr; attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR; status = ucp_listener_query(listener_, &attr); RETURN_NOT_OK(utils::FromUcsStatus("ucp_listener_query", status)); std::string raw_uri = "ucx://"; if (host.find(":") != std::string::npos) { raw_uri += '['; raw_uri += host; raw_uri += ']'; } else { raw_uri += host; } using arrow::internal::ToChars; raw_uri += ":"; raw_uri += ToChars(ntohs(reinterpret_cast<const sockaddr_in*>(&attr.sockaddr)->sin_port)); ARROW_ASSIGN_OR_RAISE(location_, arrow::flight::Location::Parse(raw_uri)); } { listening_.store(true); std::thread listener_thread(&UcxServer::DriveConnections, this); listener_thread_.swap(listener_thread); } return arrow::Status::OK(); } arrow::Status UcxServer::Wait() { std::lock_guard<std::mutex> guard(join_mutex_); try { listener_thread_.join(); } catch (const std::system_error& e) { if (e.code() != std::errc::invalid_argument) { return arrow::Status::UnknownError("could not Wait(): ", e.what()); } // else server wasn't running anyways } return arrow::Status::OK(); } arrow::Status UcxServer::Shutdown() { if (!listening_.load()) return arrow::Status::OK(); arrow::Status status; // wait for current running things to finish listening_.store(false); RETURN_NOT_OK( utils::FromUcsStatus("ucp_worker_signal", ucp_worker_signal(worker_conn_->get()))); status &= Wait(); { // reject all pending connections std::lock_guard<std::mutex> guard(pending_connections_mutex_); while (!pending_connections_.empty()) { status &= utils::FromUcsStatus( "ucp_listener_reject", ucp_listener_reject(listener_, pending_connections_.front())); pending_connections_.pop(); } ucp_listener_destroy(listener_); worker_conn_.reset(); } ucp_context_.reset(); return status; } void UcxServer::DriveConnections() { while (listening_.load()) { // wait for server to recieve connection request from client while (ucp_worker_progress(worker_conn_->get())) { } { // check for requests in queue std::lock_guard<std::mutex> guard(pending_connections_mutex_); while (!pending_connections_.empty()) { ucp_conn_request_h request = pending_connections_.front(); pending_connections_.pop(); std::thread(&UcxServer::HandleConnection, this, request).detach(); } } // check listening_ in case we're shutting down. // it's possible that shutdown was called while we were in // ucp_worker_progress above, in which case if we don't check // listening_ here, we'll enter ucp_worker_wait and get stuck. if (!listening_.load()) break; auto status = ucp_worker_wait(worker_conn_->get()); if (status != UCS_OK) { ARROW_LOG(WARNING) << utils::FromUcsStatus("ucp_worker_wait", status).ToString(); } } } void UcxServer::HandleConnection(ucp_conn_request_h request) { using arrow::internal::ToChars; std::string peer = "unknown:" + ToChars(counter_++); { ucp_conn_request_attr_t request_attr; std::memset(&request_attr, 0, sizeof(request_attr)); request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR; if (ucp_conn_request_query(request, &request_attr) == UCS_OK) { ARROW_UNUSED(utils::SockaddrToString(request_attr.client_address).Value(&peer)); } } ARROW_LOG(DEBUG) << peer << ": Received connection request"; auto maybe_worker = CreateWorker(); if (!maybe_worker.ok()) { ARROW_LOG(ERROR) << peer << ": failed to create worker" << maybe_worker.status().ToString(); auto status = ucp_listener_reject(listener_, request); if (status != UCS_OK) { ARROW_LOG(ERROR) << peer << ": " << utils::FromUcsStatus("ucp_listener_reject", status).ToString(); } return; } auto worker = maybe_worker.MoveValueUnsafe(); worker->conn_ = std::make_unique<utils::Connection>(worker->worker_); auto status = worker->conn_->CreateEndpoint(request); if (!status.ok()) { ARROW_LOG(ERROR) << peer << ": failed to create endpoint and connection: " << status.ToString(); return; } if (cuda_context_) { auto result = cuCtxPushCurrent(reinterpret_cast<CUcontext>(cuda_context_->handle())); if (result != CUDA_SUCCESS) { const char* err_name = "\0"; const char* err_string = "\0"; cuGetErrorName(result, &err_name); cuGetErrorString(result, &err_string); ARROW_LOG(ERROR) << peer << ": failed pushing cuda context on thread: " << err_name << " - " << err_string; return; } } auto st = do_work(worker.get()); if (!st.ok()) { ARROW_LOG(ERROR) << peer << ": error from do_work: " << st.ToString(); } while (!worker->conn_->is_closed()) { worker->conn_->Progress(); } // clean up status = worker->conn_->Close(); if (!status.ok()) { ARROW_LOG(ERROR) << peer << ": failed to close worker connection: " << status.ToString(); } worker->worker_.reset(); worker->conn_.reset(); ARROW_LOG(DEBUG) << peer << ": disconnected"; } arrow::Result<std::shared_ptr<UcxServer::ClientWorker>> UcxServer::CreateWorker() { auto worker = std::make_shared<ClientWorker>(); ucp_worker_params_t worker_params; std::memset(&worker_params, 0, sizeof(worker_params)); worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_FLAGS; worker_params.thread_mode = UCS_THREAD_MODE_MULTI; worker_params.flags = UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK; ucp_worker_h ucp_worker; ARROW_RETURN_NOT_OK(utils::FromUcsStatus( "ucp_worker_create", ucp_worker_create(ucp_context_->get(), &worker_params, &ucp_worker))); worker->worker_ = std::make_shared<utils::UcpWorker>(ucp_context_, ucp_worker); ARROW_RETURN_NOT_OK(setup_handlers(worker.get())); return worker; }