gloo/rendezvous/context.cc (176 lines of code) (raw):

/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "gloo/rendezvous/context.h"

#include <algorithm>
#include <cerrno>
#include <memory>
#include <stdexcept>
#include <string>
#include <system_error>
#include <vector>

#include "gloo/common/logging.h"
#include "gloo/transport/address.h"

#ifdef _WIN32
#include <winsock2.h>
#include <gloo/common/win.h>
#else
#include <unistd.h>
#endif

namespace gloo {
namespace rendezvous {

// Size of the buffer passed to gethostname().
constexpr int64_t HOSTNAME_MAX_SIZE = 256;

Context::Context(int rank, int size, int base)
    : ::gloo::Context(rank, size, base) {
}

Context::~Context() {
}

// Returns the slice of `allAddrs` (the address list published by rank `i` in
// connectFullMesh) that holds the address of the pair rank `i` created for
// this rank.
//
// Each rank publishes the addresses of its (size - 1) pairs back to back in
// ascending peer order, skipping itself. In rank i's list, our entry is
// therefore at index `rank` when rank < i, and at `rank - 1` when rank > i.
// All entries are assumed to have equal length: allAddrs.size() / (size - 1).
std::vector<char> Context::extractAddress(
    std::vector<char>& allAddrs,
    int i) {
  // Rank i stores no address for itself, so ranks above i shift down by one.
  const int adjRank = (rank > i ? rank - 1 : rank);
  const int addrSize = allAddrs.size() / (size - 1);
  return std::vector<char>(
      allAddrs.begin() + adjRank * addrSize,
      allAddrs.begin() + (adjRank + 1) * addrSize);
}

// Connects this context to every other rank, using `store` for rendezvous.
//
//   1. Publishes this host's name under "rank_<rank>". It then counts how many
//      lower ranks report the same host name. That count is passed to every
//      pair via setLocalRank().
//   2. Creates one pair per peer on `dev`. It publishes the concatenation of
//      their addresses, in ascending peer order, under the key "<rank>".
//   3. For every peer, waits (up to getTimeout()) for that peer's list. It
//      extracts the address meant for this rank and connects the pair.
//
// On success, stores `dev` and the new transport context in this object.
// Throws std::system_error if gethostname() fails.
void Context::connectFullMesh(
    rendezvous::Store& store,
    std::shared_ptr<transport::Device>& dev) {
  std::vector<char> allBytes;
  int localRank = 0;

  // Get Hostname using syscall
  char hostname[HOSTNAME_MAX_SIZE]; // NOLINT
  int rv = gethostname(hostname, HOSTNAME_MAX_SIZE);
  if (rv != 0) {
    throw std::system_error(errno, std::system_category());
  }
  // POSIX leaves it unspecified whether a truncated host name is
  // NUL-terminated. Force termination so std::string cannot read past the end
  // of the buffer.
  hostname[HOSTNAME_MAX_SIZE - 1] = '\0';
  auto localHostName = std::string(hostname);

  // Add global rank <> hostname pair to the Store. This store is then passed
  // to Gloo when connectFullMesh is called, where Gloo uses the global rank <>
  // hostname mapping to compute local ranks.
  std::string localKey("rank_" + std::to_string(rank));
  const std::vector<char> value(localHostName.begin(), localHostName.end());
  store.set(localKey, value);

  // Our local rank is the number of lower global ranks on the same host.
  for (int i = 0; i < rank; i++) {
    std::string key("rank_" + std::to_string(i));
    auto val = store.get(key);
    auto hostName = std::string((const char*)val.data(), val.size());

    if (hostName == localHostName) {
      localRank++;
    }
  }

  // Create pairs
  auto transportContext = dev->createContext(rank, size);
  transportContext->setTimeout(getTimeout());
  for (int i = 0; i < size; i++) {
    if (i == rank) {
      continue;
    }

    auto& pair = transportContext->createPair(i);
    pair->setLocalRank(localRank);
    auto addrBytes = pair->address().bytes();
    allBytes.insert(allBytes.end(), addrBytes.begin(), addrBytes.end());
  }

  // Publish all of our pair addresses under our (bare) rank number.
  // std::to_string is locale-independent, unlike stream insertion.
  store.set(std::to_string(rank), allBytes);

  // Connect every pair
  for (int i = 0; i < size; i++) {
    if (i == rank) {
      continue;
    }

    // Wait for address of other side of this pair to become available
    const std::string key = std::to_string(i);
    store.wait({key}, getTimeout());

    // Connect to other side of this pair
    auto allAddrs = store.get(key);
    auto addr = extractAddress(allAddrs, i);
    transportContext->getPair(i)->connect(addr);
  }

  device_ = dev;
  transportContext_ = std::move(transportContext);
}

// Prepares a factory that builds new contexts by exchanging pair addresses
// over `backingContext`, an already-connected context.
//
// Throws (via GLOO_ENFORCE / GLOO_THROW) if the backing context lacks a pair
// for any peer. Reserves two slots on the backing context: one for address
// payloads and one for "payload consumed" notifications. Registers per-peer
// recv/send buffers on those slots. The buffers stay bound to the storage of
// recvData_/sendData_ and the notification vectors, so that storage must not
// be reallocated afterwards.
ContextFactory::ContextFactory(std::shared_ptr<::gloo::Context> backingContext)
    : backingContext_(backingContext) {
  // We make sure that we have a fully connected context
  for (auto i = 0; i < backingContext_->size; i++) {
    if (i == backingContext_->rank) {
      continue;
    }
    try {
      GLOO_ENFORCE(
          backingContext_->getPair(i) != nullptr,
          "Missing pair in backing context");
    } catch (const std::out_of_range&) {
      GLOO_THROW("Backing context not fully connected");
    }
  }

  auto slot = backingContext_->nextSlot();
  auto notificationSlot = backingContext_->nextSlot();

  // Create buffers we'll later use to communicate pair addresses
  recvData_.resize(backingContext_->size);
  sendData_.resize(backingContext_->size);
  recvBuffers_.resize(backingContext_->size);
  sendBuffers_.resize(backingContext_->size);
  recvNotificationData_.resize(backingContext_->size);
  sendNotificationData_.resize(backingContext_->size);
  recvNotificationBuffers_.resize(backingContext_->size);
  sendNotificationBuffers_.resize(backingContext_->size);
  for (auto i = 0; i < backingContext_->size; i++) {
    if (i == backingContext_->rank) {
      continue;
    }

    // Allocate memory for recv/send
    recvData_[i].resize(kMaxAddressSize);
    sendData_[i].resize(kMaxAddressSize);

    // Create pair
    auto& pair = backingContext_->getPair(i);

    // Create payload buffers
    {
      auto recvPtr = recvData_[i].data();
      auto recvSize = recvData_[i].size();
      recvBuffers_[i] = pair->createRecvBuffer(slot, recvPtr, recvSize);
      auto sendPtr = sendData_[i].data();
      auto sendSize = sendData_[i].size();
      sendBuffers_[i] = pair->createSendBuffer(slot, sendPtr, sendSize);
    }

    // Create notification buffers
    {
      auto recvPtr = &recvNotificationData_[i];
      auto recvSize = sizeof(*recvPtr);
      recvNotificationBuffers_[i] =
          pair->createRecvBuffer(notificationSlot, recvPtr, recvSize);
      auto sendPtr = &sendNotificationData_[i];
      auto sendSize = sizeof(*sendPtr);
      sendNotificationBuffers_[i] =
          pair->createSendBuffer(notificationSlot, sendPtr, sendSize);
    }
  }
}

// Builds a new, fully connected rendezvous::Context on `dev`. It has the same
// rank, size and timeout as the backing context.
//
// Each rank creates one pair per peer. It sends the pair's address to the peer
// over the backing context, receives the peer's address, and connects. It then
// notifies the peer that the payload was consumed, so the shared
// send/recv buffers can safely be reused by the next makeContext() call.
std::shared_ptr<::gloo::Context> ContextFactory::makeContext(
    std::shared_ptr<transport::Device>& dev) {
  auto context = std::make_shared<Context>(
      backingContext_->rank, backingContext_->size);
  context->setTimeout(backingContext_->getTimeout());

  // Assume it's the same for all pairs on a device
  size_t addressSize = 0;

  // Create pairs
  auto transportContext = dev->createContext(context->rank, context->size);
  transportContext->setTimeout(context->getTimeout());
  for (auto i = 0; i < context->size; i++) {
    if (i == context->rank) {
      continue;
    }

    auto& pair = transportContext->createPair(i);
    auto address = pair->address().bytes();
    addressSize = address.size();

    // Send address of new pair to peer.
    // Overwrite in place instead of assign(). sendData_[i] must keep its
    // kMaxAddressSize length and its storage, because that storage is
    // registered with sendBuffers_[i] and its size bounds later calls.
    GLOO_ENFORCE_LE(addressSize, sendData_[i].size());
    std::copy(address.begin(), address.end(), sendData_[i].begin());
    sendBuffers_[i]->send(0, addressSize);
  }

  // Wait for remote addresses and connect peers
  for (auto i = 0; i < context->size; i++) {
    if (i == context->rank) {
      continue;
    }

    recvBuffers_[i]->waitRecv();
    auto& data = recvData_[i];
    // The remote address length is assumed to equal our local one (see
    // addressSize above).
    auto address = std::vector<char>(data.begin(), data.begin() + addressSize);
    transportContext->getPair(i)->connect(address);

    // Notify peer that we've consumed the payload
    sendNotificationBuffers_[i]->send();
  }

  // Wait for incoming notification from peers
  for (auto i = 0; i < context->size; i++) {
    if (i == context->rank) {
      continue;
    }
    recvNotificationBuffers_[i]->waitRecv();
  }

  // Wait for outgoing notifications to be flushed
  for (auto i = 0; i < context->size; i++) {
    if (i == context->rank) {
      continue;
    }
    sendNotificationBuffers_[i]->waitSend();
  }

  context->device_ = dev;
  context->transportContext_ = std::move(transportContext);
  return std::static_pointer_cast<::gloo::Context>(context);
}

} // namespace rendezvous
} // namespace gloo