in src/group.h [547:606]
// Handles an incoming reduce contribution `data` from peer `sourcePeerIndex` for the
// all-reduce operation registered under `name`.
//
// The contribution is rejected (returns false) when:
//   - the operation is unknown or not started;
//   - `syncId` does not match the operation's or its group's syncId;
//   - the sender is not one of this peer's two child slots (2*i or 2*i+1);
//   - the variant alternative of `data` differs from the local accumulator;
//   - that child slot has already been received.
// Otherwise `data` is folded into `h->localData` via `h->op` and true is returned.
//
// When this call observes that the other child slot is also filled, the partial result
// is emitted exactly once (guarded by `hasSent`):
//   - peer 0 moves it into `h->result`, calls `sendShare` and then `doCallback`;
//   - any other peer forwards it to peer `myPeerIndex / 2` via "AllReduceService::reduce".
bool reduce(std::string_view name, uint32_t syncId, size_t sourcePeerIndex, ReduceVariant& data) noexcept {
  std::shared_ptr<AllReduceOperation> h = ops.findPointer(name);
  // Each rejection reason is logged separately to help diagnose dropped messages.
  if (!h) {
    log.debug("reduce: null h\n");
    return false;
  }
  if (!h->started) {
    log.debug("reduce: not started\n");
    return false;
  }
  if (h->syncId != syncId) {
    log.debug("reduce: syncId mismatch\n");
    return false;
  }
  if (h->group->syncId != syncId) {
    return false;
  }
  size_t myPeerIndex = h->myPeerIndex;
  log.debug("%d reduce recv from %d\n", myPeerIndex, sourcePeerIndex);

  // The sender must be child 2*i or 2*i+1 of this peer. Range-check in unsigned
  // arithmetic *before* subtracting: computing `sourcePeerIndex - 2*i` first wraps
  // when the sender's index is smaller. Narrowing that result to int would then let
  // an out-of-range (peer-supplied) index alias a valid slot.
  const size_t firstChild = myPeerIndex * 2;
  if (sourcePeerIndex < firstChild || sourcePeerIndex - firstChild > 1) {
    return false;
  }
  const size_t receiveIndex = sourcePeerIndex - firstChild;

  if (h->localData.index() != data.index()) {
    return false;
  }
  // Cheap early-out for duplicates; the authoritative check is the exchange under the lock.
  if (h->hasReceived[receiveIndex].load(std::memory_order_relaxed)) {
    return false;
  }
  bool done;
  {
    std::lock_guard l(h->opMutex);
    if (h->hasReceived[receiveIndex].exchange(true, std::memory_order_relaxed)) {
      return false;
    }
    h->op(h->localData, data);
    // Because the other slot's flag is set under this same mutex before its `op`
    // runs, observing it here means that contribution has been folded in.
    done = h->hasReceived[receiveIndex ^ 1].load(std::memory_order_relaxed);
  }
  if (done) {
    // Emit the partial/final result at most once.
    if (!h->hasSent.load(std::memory_order_relaxed) && !h->hasSent.exchange(true, std::memory_order_relaxed)) {
      if (myPeerIndex == 0) {
        log.debug("receive done, enter share mode!");
        h->result = std::move(h->localData);
        sendShare(&*h, h->peers.size());
        h->flags |= 1;
        h->doCallback();
      } else {
        log.debug("receive done, pass on!");
        rpc->asyncCallback<void>(
            h->peers[myPeerIndex / 2].name, "AllReduceService::reduce",
            [h](rpc::Error* error) {
              // Only a reported error is handled here. `error` is a pointer, so guard it
              // before dereferencing.
              // NOTE(review): confirm whether the rpc layer ever calls this with nullptr on success.
              if (error) {
                h->setException(std::move(*error));
                h->doCallback();
              }
            },
            h->name, h->syncId, myPeerIndex, h->localData);
      }
    }
  }
  return true;
}