in horovod/common/ops/adasum/adasum.h [195:336]
void FusedAllreduce(std::vector<TensorTableEntry>& entries, T* grad_buffer,
                    T* recv_buffer, DataType horovod_datatype,
                    std::vector<int>& tensor_counts, int start_level,
                    Communicator_type communicator, int tag,
                    Communicator_type* reduction_comms,
                    HorovodGlobalState* global_state) {
  // Reduces the fused buffer `grad_buffer` across `communicator` in two
  // phases.
  //
  // Phase 1 (level = 1, 2, 4, ... below the rounded-down communicator size):
  //   each rank pairs with `rank ^ level`, splits the element range it
  //   currently owns into a front half and a back half, sends the half it
  //   gives up from `grad_buffer`, receives the partner's copy of the half it
  //   keeps into `recv_buffer`, and calls FusedPairwiseReduceWithComm on the
  //   kept half. The rank whose `level` bit is clear keeps the front half; the
  //   rank whose bit is set keeps the back half.
  // Phase 2 (same levels in reverse order):
  //   each rank sends the range it owns to the same partner and receives the
  //   partner's range directly into `grad_buffer`, next to its own range.
  //
  // Parameters:
  //   entries          forwarded unchanged to FusedPairwiseReduceWithComm.
  //   grad_buffer      fused input/output buffer; holds sum(tensor_counts)
  //                    elements of T.
  //   recv_buffer      scratch buffer indexed with the same element offsets
  //                    as grad_buffer during phase 1.
  //   tensor_counts    per-tensor element counts inside the fused buffer. It
  //                    is narrowed to the locally owned range during phase 1
  //                    and grown back by the same amounts during phase 2.
  //   start_level      levels smaller than this are skipped in both phases.
  //   reduction_comms  indexed by level ordinal; skipped levels still consume
  //                    an index.
  //
  // NOTE(review): the communicator size is rounded down to a power of two and
  // ranks at or above that value get no special handling here, so `rank ^
  // level` may name a rank outside the communicator when the size is not a
  // power of two. Confirm that callers guarantee a power-of-two size.

  // Held as 64-bit so that `count * per_element_size` below is evaluated in
  // 64-bit arithmetic instead of overflowing `int` for very large buffers.
  // NOTE(review): assumes PointToPointSendRecv takes 64-bit byte lengths -
  // confirm against its declaration.
  const int64_t per_element_size =
      global_state->controller->GetTypeSize(horovod_datatype);
  const int rank = GetLocalRankWithComm(communicator);
  const int comm_size = GetSizeWithComm(communicator);

  // Largest power of two that does not exceed the communicator size.
  int size = 1;
  while ((size << 1) <= comm_size) {
    size <<= 1;
  }

  // One entry per executed phase-1 level: how many elements of each tensor
  // were handed to the partner at that level. Phase 2 consumes the entries in
  // reverse order.
  std::vector<std::vector<int>> nghrCountVec;
  // Scratch space handed to FusedPairwiseReduceWithComm, six doubles per
  // tensor (presumably norms and dot products - verify against the callee).
  std::vector<double> normAndDots(tensor_counts.size() * 3 * 2);

  // Number of elements of the fused buffer currently owned by this rank.
  int myCount = 0;
  for (int count : tensor_counts) {
    myCount += count;
  }

  // ---- Phase 1: halve the owned range and reduce with the partner. ----
  int comm_index = 0;
  for (int level = 1; level < size; level <<= 1, comm_index++) {
    if (level < start_level) {
      continue;
    }
    const bool upper_half = (rank & level) != 0;
    const int neighbor_rank = rank ^ level;
    const int front_total = myCount >> 1;
    const int back_total = myCount - front_total;

    nghrCountVec.emplace_back(tensor_counts.size());
    std::vector<int>& nghr_counts = nghrCountVec.back();

    // Walk the tensors in buffer order and assign the first `front_total`
    // elements to the front half and everything after them to the back half.
    // A tensor straddling the boundary is split between the two.
    int front_so_far = 0;
    for (size_t i = 0; i < tensor_counts.size(); i++) {
      const int count = tensor_counts[i];
      const int room = front_total - front_so_far;
      const int front = room < 0 ? 0 : (count <= room ? count : room);
      const int back = count - front;
      front_so_far += front;
      if (upper_half) {
        // This rank keeps the back half; the front half goes to the partner.
        nghr_counts[i] = front;
        tensor_counts[i] = back;
      } else {
        // This rank keeps the front half; the back half goes to the partner.
        tensor_counts[i] = front;
        nghr_counts[i] = back;
      }
    }

    const int nghrCount = upper_half ? front_total : back_total;
    myCount = upper_half ? back_total : front_total;
    // Send the half being given up; receive into the slot of the half kept.
    const int sendOffset = upper_half ? 0 : front_total;
    const int recvOffset = upper_half ? front_total : 0;

    this->PointToPointSendRecv(
        reinterpret_cast<char*>(grad_buffer + sendOffset),
        nghrCount * per_element_size,
        reinterpret_cast<char*>(recv_buffer + recvOffset),
        myCount * per_element_size, horovod_datatype, neighbor_rank, tag,
        communicator, global_state);

    if (upper_half) {
      // Re-base both buffers at the start of the back half that is kept.
      grad_buffer += nghrCount;
      recv_buffer += nghrCount;
    }
    FusedPairwiseReduceWithComm(
        entries, reinterpret_cast<uint8_t*>(grad_buffer),
        reinterpret_cast<uint8_t*>(recv_buffer), horovod_datatype,
        tensor_counts, tag, reduction_comms[comm_index], !upper_half,
        normAndDots, global_state);
  }

  // ---- Phase 2: undo the halving, collecting the partner's range. ----
  // Both phases visit the same set of levels, so `step` reaches zero exactly
  // when the last level has been processed.
  size_t step = nghrCountVec.size();
  for (int level = (size >> 1); level > 0; level >>= 1) {
    if (level < start_level) {
      continue;
    }
    const bool upper_half = (rank & level) != 0;
    const int neighbor_rank = rank ^ level;

    const std::vector<int>& nghr_counts = nghrCountVec[--step];
    int nghrCount = 0;
    for (size_t i = 0; i < tensor_counts.size(); i++) {
      nghrCount += nghr_counts[i];
      tensor_counts[i] += nghr_counts[i];
    }

    // The partner's range sits right after this rank's range when this rank
    // owns the front half, and right before it otherwise.
    recv_buffer = upper_half ? grad_buffer - nghrCount : grad_buffer + myCount;
    this->PointToPointSendRecv(grad_buffer, myCount * per_element_size,
                               recv_buffer, nghrCount * per_element_size,
                               horovod_datatype, neighbor_rank, tag,
                               communicator, global_state);
    if (upper_half) {
      // Rewind to the start of the merged range.
      grad_buffer -= nghrCount;
    }
    myCount += nghrCount;
  }
}