gloo/alltoall.h (47 lines of code) (raw):

/** * Copyright (c) 2018-present, Facebook, Inc. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include "gloo/common/logging.h" #include "gloo/context.h" #include "gloo/transport/unbound_buffer.h" namespace gloo { class AlltoallOptions { public: explicit AlltoallOptions(const std::shared_ptr<Context>& context) : context(context), timeout(context->getTimeout()) {} template <typename T> void setInput(std::unique_ptr<transport::UnboundBuffer> buf) { elementSize = sizeof(T); in = std::move(buf); } template <typename T> void setInput(T* ptr, size_t elements) { elementSize = sizeof(T); in = context->createUnboundBuffer(ptr, elements * sizeof(T)); } template <typename T> void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) { elementSize = sizeof(T); out = std::move(buf); } template <typename T> void setOutput(T* ptr, size_t elements) { elementSize = sizeof(T); out = context->createUnboundBuffer(ptr, elements * sizeof(T)); } void setTag(uint32_t tag) { this->tag = tag; } void setTimeout(std::chrono::milliseconds timeout) { GLOO_ENFORCE(timeout.count() > 0); this->timeout = timeout; } protected: std::shared_ptr<Context> context; std::unique_ptr<transport::UnboundBuffer> in; std::unique_ptr<transport::UnboundBuffer> out; // Number of bytes per element. size_t elementSize = 0; // Tag for this operation. // Must be unique across operations executing in parallel. uint32_t tag = 0; // End-to-end timeout for this operation. std::chrono::milliseconds timeout; friend void alltoall(AlltoallOptions&); }; void alltoall(AlltoallOptions& opts); } // namespace gloo