gloo/cuda_allreduce_halving_doubling.h

/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <math.h>
#include <stddef.h>
#include <string.h>

#include "gloo/algorithm.h"
#include "gloo/common/error.h"
#include "gloo/cuda.h"
#include "gloo/cuda_workspace.h"

namespace gloo {

template <typename T, typename W = CudaHostWorkspace<T> >
class CudaAllreduceHalvingDoubling : public Algorithm {
 public:
  CudaAllreduceHalvingDoubling(
      const std::shared_ptr<Context>& context,
      const std::vector<T*>& ptrs,
      const int count,
      const std::vector<cudaStream_t>& streams = std::vector<cudaStream_t>(),
      bool pipelineBroadcastAndReduce = false);

  virtual ~CudaAllreduceHalvingDoubling() = default;

  virtual void run() override;

 protected:
  void initBinaryBlocks();
  void devicePointerInit();

  // Both workspace types have their own initialization function.
  template <typename U = W>
  void init(
      typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value,
                              typename U::Pointer>::type* = 0);

  template <typename U = W>
  void init(
      typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value,
                              typename U::Pointer>::type* = 0);

  template <typename U = W>
  void initReductionsAndBroadcasts(
      typename std::enable_if<std::is_same<U, CudaHostWorkspace<T> >::value,
                              typename U::Pointer>::type* = 0);

  template <typename U = W>
  void initReductionsAndBroadcasts(
      typename std::enable_if<std::is_same<U, CudaDeviceWorkspace<T> >::value,
                              typename U::Pointer>::type* = 0);

  std::vector<CudaDevicePointer<T> > devicePtrs_;
  std::vector<CudaStream> streams_;
  typename W::Pointer scratch_;
  CudaStream* scratchStream_;

  const int count_;
  const int bytes_;
  const size_t steps_;
  const size_t chunks_;
  const size_t chunkSize_;
  const size_t chunkBytes_;

  const CudaReductionFunction<T>* fn_;

  // offsets into the data buffer from which to send during the reduce-scatter;
  // these become the offsets at which the process receives during the allgather
  // indexed by step
  std::vector<size_t> sendOffsets_;

  // offsets at which data is reduced during the reduce-scatter and sent from in
  // the allgather
  std::vector<size_t> recvOffsets_;

  std::vector<std::unique_ptr<transport::Buffer>> sendDataBufs_;
  std::vector<std::unique_ptr<transport::Buffer>> recvDataBufs_;

  std::unique_ptr<transport::Buffer> smallerBlockSendDataBuf_;
  std::unique_ptr<transport::Buffer> smallerBlockRecvDataBuf_;

  std::vector<std::unique_ptr<transport::Buffer>> largerBlockSendDataBufs_;
  std::vector<std::unique_ptr<transport::Buffer>> largerBlockRecvDataBufs_;

  std::vector<size_t> sendCounts_;
  std::vector<size_t> recvCounts_;
  size_t sendCountToLargerBlock_;

  int dummy_;
  std::vector<std::unique_ptr<transport::Buffer>> sendNotificationBufs_;
  std::vector<std::unique_ptr<transport::Buffer>> recvNotificationBufs_;

  std::unique_ptr<LocalOp<T>> reduceBeforeFirstSend_;
  std::unique_ptr<LocalOp<T>> reduceBeforeFirstRecv_;
  std::unique_ptr<LocalOp<T> > localReduceOp_;
  std::unique_ptr<LocalOp<T> > localBroadcastOp_;

  // buffer where data is received prior to being reduced
  typename W::Pointer recvBuf_;

  typename W::Pointer scratchPtrForFirstSend_;
  typename W::Pointer scratchPtrForFirstRecv_;

  std::vector<CudaDevicePointer<T>> devicePtrsForFirstSend_;
  std::vector<CudaDevicePointer<T>> devicePtrsForFirstRecv_;

  std::vector<typename W::Pointer> scratchPtrForBroadcast_;
  std::vector<std::vector<CudaDevicePointer<T>>> devicePtrsForBroadcast_;
  std::vector<std::unique_ptr<LocalOp<T>>> broadcastOps_;

  bool pipelined_;

  // for non-power-of-two number of processes, partition the processes into
  // binary blocks and keep track of which block each process is in, as well as
  // the adjoining larger and smaller blocks (with which communication will be
  // required)
  uint32_t offsetToMyBinaryBlock_;
  uint32_t myBinaryBlockSize_;
  uint32_t stepsWithinBlock_;
  uint32_t rankInBinaryBlock_;
  uint32_t nextSmallerBlockSize_;
  uint32_t nextLargerBlockSize_;

  int slotOffset_;
};

} // namespace gloo
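
The binary-block fields at the end of the class handle process counts that are not powers of two: the ranks are partitioned into power-of-two sized blocks, and each process records its own block plus the adjoining larger and smaller blocks it must exchange data with. The standalone sketch below is not part of gloo and its names (BlockInfo, locateBinaryBlock) are placeholders; it only illustrates one way such a decomposition could be computed from a rank and a process count, assuming blocks are laid out from the largest power of two to the smallest.

// Hypothetical sketch: decompose a process count into binary blocks and
// locate a given rank within them. Not gloo code; for illustration only.
#include <cstdint>
#include <cstdio>

struct BlockInfo {
  uint32_t offsetToMyBinaryBlock; // rank offset where my block starts
  uint32_t myBinaryBlockSize;     // size of my block (a power of two)
  uint32_t nextSmallerBlockSize;  // 0 if my block is the smallest
  uint32_t nextLargerBlockSize;   // 0 if my block is the largest
};

BlockInfo locateBinaryBlock(uint32_t rank, uint32_t size) {
  BlockInfo info{0, 0, 0, 0};
  uint32_t offset = 0;
  uint32_t prevBlock = 0;
  // Each set bit of `size`, from most to least significant, contributes one
  // binary block of that power-of-two size.
  for (int bit = 31; bit >= 0; --bit) {
    const uint32_t block = size & (1u << bit);
    if (block == 0) {
      continue;
    }
    if (info.myBinaryBlockSize != 0 && info.nextSmallerBlockSize == 0) {
      info.nextSmallerBlockSize = block; // first block after mine
    }
    if (info.myBinaryBlockSize == 0 && rank >= offset && rank < offset + block) {
      info.offsetToMyBinaryBlock = offset;
      info.myBinaryBlockSize = block;
      info.nextLargerBlockSize = prevBlock; // block just before mine, if any
    }
    prevBlock = block;
    offset += block;
  }
  return info;
}

int main() {
  // 11 processes decompose into blocks of 8, 2, and 1.
  for (uint32_t rank = 0; rank < 11; ++rank) {
    const BlockInfo b = locateBinaryBlock(rank, 11);
    std::printf("rank %u: block size %u at offset %u (larger %u, smaller %u)\n",
                rank, b.myBinaryBlockSize, b.offsetToMyBinaryBlock,
                b.nextLargerBlockSize, b.nextSmallerBlockSize);
  }
  return 0;
}

For example, 11 processes decompose into blocks of 8, 2, and 1, so rank 9 sits in the two-process block with the eight-process block as its larger neighbor and the single-process block as its smaller one.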
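
A minimal usage sketch follows, assuming a connected gloo::Context is already available (for example via gloo's rendezvous machinery) and that devicePtr refers to count floats in CUDA device memory. The helper function allreduceExample is hypothetical; only the constructor arguments and run() call come from the header above.

#include <memory>
#include <vector>

#include "gloo/context.h"
#include "gloo/cuda_allreduce_halving_doubling.h"

void allreduceExample(
    const std::shared_ptr<gloo::Context>& context,
    float* devicePtr,
    int count) {
  // One device pointer per local GPU buffer participating in the reduction.
  std::vector<float*> ptrs = {devicePtr};

  // Default workspace is CudaHostWorkspace<float>; CudaDeviceWorkspace<float>
  // can be supplied as the second template argument instead.
  gloo::CudaAllreduceHalvingDoubling<float> allreduce(context, ptrs, count);

  // Runs the reduce-scatter followed by the allgather; on return every
  // participating process holds the fully reduced result in devicePtr.
  allreduce.run();
}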