gloo/allreduce_ring.h (81 lines of code) (raw):
/**
* Copyright (c) 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <stddef.h>
#include <string.h>
#include "gloo/algorithm.h"
#include "gloo/context.h"
namespace gloo {
template <typename T>
class AllreduceRing : public Algorithm {
public:
AllreduceRing(
const std::shared_ptr<Context>& context,
const std::vector<T*>& ptrs,
const int count,
const ReductionFunction<T>* fn = ReductionFunction<T>::sum)
: Algorithm(context),
ptrs_(ptrs),
count_(count),
bytes_(count_ * sizeof(T)),
fn_(fn) {
inbox_ = static_cast<T*>(malloc(bytes_));
outbox_ = static_cast<T*>(malloc(bytes_));
if (this->contextSize_ == 1) {
return;
}
auto& leftPair = this->getLeftPair();
auto& rightPair = this->getRightPair();
auto slot = this->context_->nextSlot();
// Buffer to send to (rank+1).
sendDataBuf_ = rightPair->createSendBuffer(slot, outbox_, bytes_);
// Buffer that (rank-1) writes to.
recvDataBuf_ = leftPair->createRecvBuffer(slot, inbox_, bytes_);
// Dummy buffers for localized barrier.
// Before sending to the right, we only need to know that the node
// on the right is done using the inbox that's about to be written
// into. No need for a global barrier.
auto notificationSlot = this->context_->nextSlot();
sendNotificationBuf_ =
leftPair->createSendBuffer(notificationSlot, &dummy_, sizeof(dummy_));
recvNotificationBuf_ =
rightPair->createRecvBuffer(notificationSlot, &dummy_, sizeof(dummy_));
}
virtual ~AllreduceRing() {
if (inbox_ != nullptr) {
free(inbox_);
}
if (outbox_ != nullptr) {
free(outbox_);
}
}
void run() {
if (count_ == 0) {
return;
}
// Reduce specified pointers into ptrs_[0]
for (int i = 1; i < ptrs_.size(); i++) {
fn_->call(ptrs_[0], ptrs_[i], count_);
}
// Intialize outbox with locally reduced values
memcpy(outbox_, ptrs_[0], bytes_);
int numRounds = this->contextSize_ - 1;
for (int round = 0; round < numRounds; round++) {
// Initiate write to inbox of node on the right
sendDataBuf_->send();
// Wait for inbox write from node on the left
recvDataBuf_->waitRecv();
// Reduce
fn_->call(ptrs_[0], inbox_, count_);
// Wait for outbox write to complete
sendDataBuf_->waitSend();
// Prepare for next round if necessary
if (round < (numRounds - 1)) {
memcpy(outbox_, inbox_, bytes_);
}
// Send notification to node on the left that
// this node is ready for an inbox write.
sendNotificationBuf_->send();
// Wait for notification from node on the right
recvNotificationBuf_->waitRecv();
}
// Broadcast ptrs_[0]
for (int i = 1; i < ptrs_.size(); i++) {
memcpy(ptrs_[i], ptrs_[0], bytes_);
}
}
protected:
std::vector<T*> ptrs_;
const int count_;
const int bytes_;
const ReductionFunction<T>* fn_;
T* inbox_;
T* outbox_;
std::unique_ptr<transport::Buffer> sendDataBuf_;
std::unique_ptr<transport::Buffer> recvDataBuf_;
int dummy_;
std::unique_ptr<transport::Buffer> sendNotificationBuf_;
std::unique_ptr<transport::Buffer> recvNotificationBuf_;
};
} // namespace gloo