gloo/allreduce.h (119 lines of code) (raw):

/** * Copyright (c) 2018-present, Facebook, Inc. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <functional> #include <memory> #include <vector> #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" namespace gloo { namespace detail { struct AllreduceOptionsImpl { // This type describes the function to use for element wise reduction. // // Its arguments are: // 1. non-const output pointer // 2. const input pointer 1 (may be equal to 1) // 3. const input pointer 2 (may be equal to 1) // 4. number of elements to reduce. // // Note that this function is not strictly typed and takes void pointers. // This is specifically done to avoid the need for a templated options class // and templated algorithm implementations. We found this adds very little // value for the increase in compilation time and code size. // using Func = std::function<void(void*, const void*, const void*, size_t)>; enum Algorithm { UNSPECIFIED = 0, RING = 1, BCUBE = 2, }; explicit AllreduceOptionsImpl(const std::shared_ptr<Context>& context) : context(context), timeout(context->getTimeout()), algorithm(UNSPECIFIED) {} std::shared_ptr<Context> context; // End-to-end timeout for this operation. std::chrono::milliseconds timeout; // Algorithm selection. Algorithm algorithm; // Input and output buffers. // The output is used as input if input is not specified. std::vector<std::unique_ptr<transport::UnboundBuffer>> in; std::vector<std::unique_ptr<transport::UnboundBuffer>> out; // Number of elements. size_t elements = 0; // Number of bytes per element. size_t elementSize = 0; // Reduction function. Func reduce; // Tag for this operation. // Must be unique across operations executing in parallel. uint32_t tag = 0; // This is the maximum size of each I/O operation (send/recv) of which // two are in flight at all times. A smaller value leads to more // overhead and a larger value leads to poor cache behavior. static constexpr size_t kMaxSegmentSize = 1024 * 1024; // Internal use only. This is used to exercise code paths where we // have more than 2 segments per rank without making the tests slow // (because they would require millions of elements if the default // were not configurable). size_t maxSegmentSize = kMaxSegmentSize; }; } // namespace detail class AllreduceOptions { public: using Func = detail::AllreduceOptionsImpl::Func; using Algorithm = detail::AllreduceOptionsImpl::Algorithm; explicit AllreduceOptions(const std::shared_ptr<Context>& context) : impl_(context) {} void setAlgorithm(Algorithm algorithm) { impl_.algorithm = algorithm; } template <typename T> void setInput(std::unique_ptr<transport::UnboundBuffer> buf) { std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1); bufs[0] = std::move(buf); setInputs<T>(std::move(bufs)); } template <typename T> void setInputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) { impl_.elements = bufs[0]->size / sizeof(T); impl_.elementSize = sizeof(T); impl_.in = std::move(bufs); } template <typename T> void setInput(T* ptr, size_t elements) { setInputs(&ptr, 1, elements); } template <typename T> void setInputs(std::vector<T*> ptrs, size_t elements) { setInputs(ptrs.data(), ptrs.size(), elements); } template <typename T> void setInputs(T** ptrs, size_t len, size_t elements) { impl_.elements = elements; impl_.elementSize = sizeof(T); impl_.in.reserve(len); for (size_t i = 0; i < len; i++) { impl_.in.push_back( impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T))); } } template <typename T> void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) { std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1); bufs[0] = std::move(buf); setOutputs<T>(std::move(bufs)); } template <typename T> void setOutputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) { impl_.elements = bufs[0]->size / sizeof(T); impl_.elementSize = sizeof(T); impl_.out = std::move(bufs); } template <typename T> void setOutput(T* ptr, size_t elements) { setOutputs(&ptr, 1, elements); } template <typename T> void setOutputs(std::vector<T*> ptrs, size_t elements) { setOutputs(ptrs.data(), ptrs.size(), elements); } template <typename T> void setOutputs(T** ptrs, size_t len, size_t elements) { impl_.elements = elements; impl_.elementSize = sizeof(T); impl_.out.reserve(len); for (size_t i = 0; i < len; i++) { impl_.out.push_back( impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T))); } } void setReduceFunction(Func fn) { impl_.reduce = fn; } void setTag(uint32_t tag) { impl_.tag = tag; } void setMaxSegmentSize(size_t maxSegmentSize) { impl_.maxSegmentSize = maxSegmentSize; } void setTimeout(std::chrono::milliseconds timeout) { impl_.timeout = timeout; } protected: detail::AllreduceOptionsImpl impl_; friend void allreduce(const AllreduceOptions&); }; void allreduce(const AllreduceOptions& opts); } // namespace gloo