void FusedAllreduce()

in horovod/common/ops/adasum/adasum.h [195:336]


  void FusedAllreduce(std::vector<TensorTableEntry>& entries, T* grad_buffer,
                      T* recv_buffer, DataType horovod_datatype,
                      std::vector<int>& tensor_counts, int start_level,
                      Communicator_type communicator, int tag,
                      Communicator_type* reduction_comms,
                      HorovodGlobalState* global_state) {
    // Overview
    // --------
    // Works on a fused buffer (`grad_buffer`) that holds several tensors back
    // to back; `tensor_counts[i]` is the number of elements of tensor i that
    // this rank currently owns inside the buffer.
    //
    // Phase 1 (levels 1, 2, 4, ... below the power-of-two size): the rank
    //   pairs with `rank ^ level`, keeps one half of its current element
    //   range, sends the other half to the neighbor, receives the neighbor's
    //   copy of the kept half into `recv_buffer`, and hands both halves to
    //   FusedPairwiseReduceWithComm. `tensor_counts` is narrowed in place to
    //   describe only the kept half.
    // Phase 2 (same levels, reverse order): the rank swaps its kept range
    //   with the neighbor's range directly inside `grad_buffer`, and
    //   `tensor_counts` gets the counts handed away in phase 1 added back.
    //
    // Parameters as used here:
    //   entries          forwarded untouched to FusedPairwiseReduceWithComm.
    //   grad_buffer      input/output data; phase 2 writes into it in place.
    //   recv_buffer      scratch space for phase 1 receives.
    //   tensor_counts    per-tensor element counts; modified during the call
    //                    and restored by phase 2. Assumed to be non-negative.
    //   start_level      levels smaller than this are skipped in both phases.
    //   reduction_comms  indexed by the ordinal of the level (0 for level 1,
    //                    1 for level 2, ...), counting skipped levels too.
    //
    // NOTE(review): only ranks/levels below the largest power of two that is
    // <= the communicator size are considered; presumably callers guarantee
    // a power-of-two communicator - confirm.
    // NOTE(review): byte lengths are computed as `int * int`; for very large
    // fused buffers this could overflow - confirm against the parameter type
    // of PointToPointSendRecv before widening.
    const int per_element_size =
        global_state->controller->GetTypeSize(horovod_datatype);
    const int rank = GetLocalRankWithComm(communicator);
    const int comm_size = GetSizeWithComm(communicator);

    // Largest power of two that does not exceed comm_size (1 if comm_size < 2).
    int size = 1;
    while ((size << 1) <= comm_size) {
      size <<= 1;
    }

    // One entry per executed phase-1 level: how many elements of each tensor
    // were handed to the neighbor at that level. Consumed in reverse order by
    // phase 2.
    std::vector<std::vector<int>> nghrCountVec;
    int nghrCountVec_index = 0;

    // Scratch storage passed to FusedPairwiseReduceWithComm
    // (six doubles per tensor).
    std::vector<double> normAndDots(tensor_counts.size() * 3 * 2);

    // Number of elements this rank currently owns across all tensors.
    int myCount = 0;
    for (size_t i = 0; i < tensor_counts.size(); i++) {
      myCount += tensor_counts[i];
    }

    // ---- Phase 1: halve the owned range and reduce pairwise ----
    // comm_index advances on every level, including skipped ones, so that
    // reduction_comms stays aligned with the level ordinal.
    int comm_index = 0;
    for (int level = 1; level < size; level <<= 1, comm_index++) {
      if (level < start_level) {
        continue;
      }

      const int neighbor_rank = rank ^ level;
      // True when this rank keeps the first half of its range at this level.
      const bool keepsFirstHalf = (rank & level) == 0;
      const int firstHalfMyCount = (myCount >> 1);
      const int secondHalfMyCount = myCount - firstHalfMyCount;

      // Appends a zero-filled vector with one slot per tensor.
      nghrCountVec.emplace_back(tensor_counts.size());
      std::vector<int>& nghrCounts = nghrCountVec[nghrCountVec_index];

      int nghrCount = 0;
      int sendOffset = 0;
      int recvOffset = 0;
      if (!keepsFirstHalf) {
        // Keep the second half; the leading firstHalfMyCount elements go to
        // the neighbor.
        myCount = secondHalfMyCount;
        nghrCount = firstHalfMyCount;
        sendOffset = 0;
        recvOffset = nghrCount;

        // Walk the tensors front to back, giving elements away until the
        // neighbor's quota is filled. The tensor that straddles the boundary
        // is split; everything after it stays local.
        int nghrCountSoFar = 0;
        for (size_t i = 0; i < tensor_counts.size(); i++) {
          const int nghrRemaining = nghrCount - nghrCountSoFar;
          const int given = (tensor_counts[i] <= nghrRemaining)
                                ? tensor_counts[i]
                                : nghrRemaining;
          nghrCounts[i] = given;
          tensor_counts[i] -= given;
          nghrCountSoFar += given;
        }
      } else {
        // Keep the first half; the trailing secondHalfMyCount elements go to
        // the neighbor.
        myCount = firstHalfMyCount;
        nghrCount = secondHalfMyCount;
        sendOffset = myCount;
        recvOffset = 0;

        // Walk the tensors front to back, keeping elements until the local
        // quota is filled. The tensor that straddles the boundary is split;
        // everything after it goes to the neighbor.
        int myCountSoFar = 0;
        for (size_t i = 0; i < tensor_counts.size(); i++) {
          const int myRemaining = myCount - myCountSoFar;
          const int kept = (tensor_counts[i] <= myRemaining)
                               ? tensor_counts[i]
                               : myRemaining;
          nghrCounts[i] = tensor_counts[i] - kept;
          tensor_counts[i] = kept;
          myCountSoFar += kept;
        }
      }

      nghrCountVec_index++;

      // Send the half being given away; receive the neighbor's copy of the
      // half being kept.
      this->PointToPointSendRecv(
          reinterpret_cast<char*>(&grad_buffer[sendOffset]),
          nghrCount * per_element_size,
          reinterpret_cast<char*>(&recv_buffer[recvOffset]),
          myCount * per_element_size, horovod_datatype, neighbor_rank, tag,
          communicator, global_state);
      if (!keepsFirstHalf) {
        // The kept half starts after the elements that were given away.
        grad_buffer = &grad_buffer[nghrCount];
        recv_buffer = &recv_buffer[nghrCount];
      }
      FusedPairwiseReduceWithComm(
          entries, reinterpret_cast<uint8_t*>(grad_buffer),
          reinterpret_cast<uint8_t*>(recv_buffer), horovod_datatype,
          tensor_counts, tag, reduction_comms[comm_index], keepsFirstHalf,
          normAndDots, global_state);
    }

    // ---- Phase 2: exchange results back, largest level first ----
    for (int level = (size >> 1); level > 0; level >>= 1) {
      if (level < start_level) {
        continue;
      }
      const int neighbor_rank = rank ^ level;
      const bool keepsFirstHalf = (rank & level) == 0;

      // Undo the narrowing recorded for this level in phase 1.
      nghrCountVec_index--;
      const std::vector<int>& nghrCounts = nghrCountVec[nghrCountVec_index];
      int nghrCount = 0;
      for (size_t i = 0; i < tensor_counts.size(); i++) {
        nghrCount += nghrCounts[i];
        tensor_counts[i] += nghrCounts[i];
      }

      // The neighbor's range is received in place: directly after the local
      // range when this rank kept the first half, directly before it
      // otherwise.
      if (keepsFirstHalf) {
        recv_buffer = grad_buffer + myCount;
      } else {
        recv_buffer = grad_buffer - nghrCount;
      }
      this->PointToPointSendRecv(grad_buffer, myCount * per_element_size,
                                 recv_buffer, nghrCount * per_element_size,
                                 horovod_datatype, neighbor_rank, tag,
                                 communicator, global_state);
      if (!keepsFirstHalf) {
        // The owned range now begins where the neighbor's data was written.
        grad_buffer = grad_buffer - nghrCount;
      }
      myCount += nghrCount;
    }
  }