gloo/allgather.cc (68 lines of code) (raw):

/**
 * Copyright (c) 2018-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "gloo/allgather.h"

#include <array>
#include <cstdint>
#include <cstring>
#include <string>

#include "gloo/common/logging.h"
#include "gloo/types.h"

namespace gloo {

// Ring allgather over opts.out.
//
// The output buffer is viewed as context->size equally sized segments.
// Segment k holds rank k's contribution. If opts.in is set, it is first
// copied into this rank's segment; otherwise the caller is assumed to have
// placed it there already (in-place operation). Each rank then repeatedly
// sends to rank + 1 and receives from rank - 1 until every segment has
// arrived.
//
// Throws (via GLOO_ENFORCE) if elementSize is not positive, if the output
// buffer is missing, if a ring neighbor has no connection, or if the buffer
// sizes are inconsistent with context->size.
void allgather(AllgatherOptions& opts) {
  const auto& context = opts.context;
  transport::UnboundBuffer* in = opts.in.get();
  transport::UnboundBuffer* out = opts.out.get();
  const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag);

  // Sanity checks
  GLOO_ENFORCE(opts.elementSize > 0);
  // Every path below dereferences the output buffer.
  GLOO_ENFORCE(out != nullptr, "allgather requires an output buffer");

  // The ring needs a connection to both neighbors. The only exception is
  // when the neighbor is this process itself, which is the case when
  // context->size == 1.
  const auto enforceConnected = [&context](int peer) {
    GLOO_ENFORCE(
        peer == context->rank || context->getPair(peer),
        "missing connection between rank " + std::to_string(context->rank) +
            " (this process) and rank " + std::to_string(peer));
  };
  const auto recvRank = (context->size + context->rank - 1) % context->size;
  enforceConnected(recvRank);
  const auto sendRank = (context->size + context->rank + 1) % context->size;
  enforceConnected(sendRank);

  if (in != nullptr) {
    GLOO_ENFORCE_EQ(out->size, in->size * context->size);
  } else {
    GLOO_ENFORCE_EQ(out->size % context->size, 0);
  }

  // Size of one rank's segment, and of the whole output buffer.
  const size_t inBytes = out->size / context->size;
  const size_t outBytes = out->size;

  // If the input buffer is specified, this is NOT an in place operation,
  // and the output buffer needs to be primed with the input. Empty input is
  // skipped because memcpy requires valid pointers even for zero bytes.
  if (in != nullptr && in->size > 0) {
    std::memcpy(
        static_cast<uint8_t*>(out->ptr) + context->rank * in->size,
        static_cast<uint8_t*>(in->ptr),
        in->size);
  }

  // Short circuit if there is only a single process, or if there is nothing
  // to exchange. The zero-size check also keeps the "% outBytes" below from
  // dividing by zero.
  if (context->size == 1 || outBytes == 0) {
    return;
  }

  // Each segment is transferred as two chunks so that two send/recv pairs
  // can be in flight at once. The segment size may be odd, so chunk 1 takes
  // the remainder.
  std::array<size_t, 2> chunkSize;
  chunkSize[0] = inBytes / 2;
  chunkSize[1] = inBytes - chunkSize[0];
  std::array<size_t, 2> chunkOffset;
  chunkOffset[0] = 0;
  chunkOffset[1] = chunkSize[0];

  // At step i this rank sends chunk (i & 1) of segment (rank - i / 2) mod
  // size to sendRank. It receives the same chunk of segment
  // (rank - i / 2 - 1) mod size from recvRank. Over the size - 1 rotations
  // it receives every segment other than its own.
  const int steps = (context->size - 1) * 2;
  for (int i = 0; i < steps; i++) {
    // Adding context->size keeps the segment index non-negative. The modulo
    // by outBytes wraps the resulting byte offset back into the buffer.
    const size_t sendSegment = context->size + context->rank - (i / 2);
    const size_t recvSegment = sendSegment - 1;
    const size_t chunk = i & 0x1;
    const size_t sendOffset =
        ((sendSegment * inBytes) + chunkOffset[chunk]) % outBytes;
    const size_t recvOffset =
        ((recvSegment * inBytes) + chunkOffset[chunk]) % outBytes;
    const size_t size = chunkSize[chunk];

    // Wait for pending operations to complete to synchronize with the
    // previous iteration. Because we kick off two operations before
    // getting here we always wait for the next-to-last operation.
    if (i >= 2) {
      out->waitSend(opts.timeout);
      out->waitRecv(opts.timeout);
    }
    out->send(sendRank, slot, sendOffset, size);
    out->recv(recvRank, slot, recvOffset, size);
  }

  // Drain the two send/recv pairs still outstanding from the last two steps.
  for (int i = 0; i < 2; i++) {
    out->waitSend(opts.timeout);
    out->waitRecv(opts.timeout);
  }
}

} // namespace gloo