in gloo/cuda_broadcast_one_to_all.cc [93:162]
void CudaBroadcastOneToAll<T, W>::run() {
  // Broadcasts the root process's buffer (devicePtrs_[rootPointerRank_] on
  // rank rootRank_) to every other process in the context. Data is staged
  // through the host buffer scratch_. Within each process, localBroadcastOp_
  // (when present) then fans the root pointer's contents out locally.
  //
  // When synchronizeDeviceOutputs_ is set, run() blocks on the final device
  // work it issues before returning:
  //  - the local broadcast, when one exists;
  //  - otherwise, on a non-root process, the host-to-device copy.

  // Kick off the local broadcast, if any; optionally block until it is done.
  auto broadcastLocally = [this]() {
    if (!localBroadcastOp_) {
      return;
    }
    localBroadcastOp_->runAsync();
    if (synchronizeDeviceOutputs_) {
      localBroadcastOp_->wait();
    }
  };

  // With only one process there are no peers; only the local step applies.
  if (contextSize_ == 1) {
    broadcastLocally();
    return;
  }

  CudaStream& stream = streams_[rootPointerRank_];

  if (contextRank_ != rootRank_) {
    // Non-root: make sure the host-to-device copy from the previous run has
    // finished before telling the root it may send.
    // NOTE: this only waits for the last copyAsync, not for the whole stream.
    stream.wait();

    receiver_->clearToSendBuffer->send();
    receiver_->recvBuffer->waitRecv();

    // Move the received host data into the root device pointer.
    stream.copyAsync(devicePtrs_[rootPointerRank_], scratch_);

    if (localBroadcastOp_) {
      // The local broadcast synchronizes on the root pointer, so the copy
      // above needs no explicit wait here.
      broadcastLocally();
    } else if (synchronizeDeviceOutputs_) {
      // No local broadcast to order after the copy: wait on it directly.
      stream.wait();
    }
    return;
  }

  // Root: copy the device buffer into scratch_ and wait for the copy to land
  // before any send is issued.
  stream.copyAsync(scratch_, devicePtrs_[rootPointerRank_]);
  stream.wait();

  // Serve peers in rank order: wait for each one's clear-to-send, then start
  // its send without waiting for it to finish.
  for (int peer = 0; peer < contextSize_; ++peer) {
    if (peer != contextRank_) {
      sender_[peer]->clearToSendBuffer->waitRecv();
      sender_[peer]->sendBuffer->send();
    }
  }

  // Run the local broadcast while the sends are still in flight.
  broadcastLocally();

  // Block until every outstanding send has completed.
  for (int peer = 0; peer < contextSize_; ++peer) {
    if (peer != contextRank_) {
      sender_[peer]->sendBuffer->waitSend();
    }
  }
}