void Communicator::topKSparsAllReduce()

in src/io/communicator.cc [721:807]


void Communicator::topKSparsAllReduce(size_t num, void *accumulation,
                                      Context *ctx) {
  CHECK_EQ(dataSize, sizeof(float))
      << "This function depends on thrust and supports only fp32 currently";

  if (sparsInitialized == false) sparsInit();
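  // (sparsInit is assumed to set up the sparse-communication buffers and the
  // cuSPARSE handle used below; it runs only on the first call)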

  // use gradient accumulation
  if (accumulation != NULL) {
    // add the previous accumulation
    cuda::add(num, static_cast<float *>(fusedSendBuff),
              static_cast<float *>(accumulation),
              static_cast<float *>(fusedSendBuff), ctx->c1);
    // backup the fusedSendBuff
    CUDA_CHECK(cudaMemcpyAsync(backupBuff, (const void *)fusedSendBuff,
                               sizeof(float) * num, cudaMemcpyDeviceToDevice,
                               ctx->c1));
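    // backupBuff now holds the dense gradients (including the carried-over
    // accumulation); it is used below to compute the residual that stays in
    // accumulation after the topK values are sent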
  }

  // generate an index array and sort fusedSendBuff from large to small values
  cuda::generateindex(num, fusedIndex, ctx->c1);
  cuda::sortbykey(num, static_cast<float *>(fusedSendBuff), fusedIndex,
                  ctx->c1);

  // determine the number of topK elements to communicate
  int nnzMax = (int)ceil(threshold * num);
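  // threshold is the fraction of entries kept, so only the nnzMax largest
  // sorted values (and their indices) take part in the communication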

  // compute the new gradient accumulation (the residual that is not sent)
  float alpha = 1.0;
  if (accumulation != NULL) {
    CUDA_CHECK(cudaMemsetAsync(accumulation, 0, num * sizeof(float), ctx->c1));
    CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c1));
    CUSPARSE_CHECK(cusparseSaxpyi(
        cusparse_handle, nnzMax, &alpha, static_cast<float *>(fusedSendBuff),
        fusedIndex, static_cast<float *>(accumulation),
        CUSPARSE_INDEX_BASE_ONE));
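    // accumulation now holds the topK values scattered back to their original
    // positions; subtracting them from the backup leaves the unsent residual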
    cuda::sub(num, static_cast<float *>(backupBuff),
              static_cast<float *>(accumulation),
              static_cast<float *>(accumulation), ctx->c1);
  }

  // pack the topK indices and values to be sent
  CUDA_CHECK(cudaMemcpyAsync(sparsSendBuff, (const void *)fusedIndex,
                             sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice,
                             ctx->c1));
  CUDA_CHECK(
      cudaMemcpyAsync((void *)(static_cast<float *>(sparsSendBuff) + nnzMax),
                      (const void *)fusedSendBuff, sizeof(float) * nnzMax,
                      cudaMemcpyDeviceToDevice, ctx->c1));
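  // sparsSendBuff layout: nnzMax 32-bit indices followed by nnzMax float
  // values, hence the 2 * nnzMax floats all-gathered below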

  // make the NCCL stream s wait for the memcpys on c1 to complete
  CUDA_CHECK(cudaEventRecord(event, ctx->c1));
  CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));

  // all-gather all the sparse gradients
  NCCLCHECK(ncclAllGather((const void *)sparsSendBuff, (void *)sparsRecvBuff,
                          2 * nnzMax, ncclFloat, comm, ctx->s));
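  // sparsRecvBuff now holds one (indices, values) block of 2 * nnzMax floats
  // from every rank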

  // make stream c2 wait for the all-gather on s to complete
  CUDA_CHECK(cudaEventRecord(event, ctx->s));
  CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));

  // reduce the sparse gradients, first setting the sum buffer to zero
  CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), ctx->c2));

  size_t offset = 0;

  CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c2));

  // add the sparse gradients from each rank to the sum buffer to finish the
  // all-reduce process
  for (int i = 0; i < world_size; i++) {
    CUDA_CHECK(cudaMemcpyAsync(
        (void *)xInd,
        (const void *)(static_cast<float *>(sparsRecvBuff) + offset),
        sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice, ctx->c2));
    offset += nnzMax;
    CUDA_CHECK(cudaMemcpyAsync(
        (void *)xVal,
        (const void *)(static_cast<float *>(sparsRecvBuff) + offset),
        sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, ctx->c2));
    offset += nnzMax;
    CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzMax, &alpha, xVal, xInd,
                                  static_cast<float *>(fusedRecvBuff),
                                  CUSPARSE_INDEX_BASE_ONE));
  }
}
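
As a reading aid, the CPU-only sketch below mirrors the same scheme: add the carried-over residual, pick the nnzMax largest values, keep the unsent part as the new accumulation (error feedback), and scatter-add every rank's (index, value) pairs into a dense sum. The names topk_allreduce_sketch and Rank are hypothetical and this is only an illustration of the idea, not the library's API; the device code above performs the same steps with thrust, cuSPARSE and NCCL.

#include <algorithm>  // partial_sort, min
#include <cmath>      // ceil
#include <cstddef>    // size_t
#include <numeric>    // iota
#include <vector>

struct Rank {
  std::vector<float> grad;          // current dense gradients of this rank
  std::vector<float> accumulation;  // residual carried over from earlier rounds
};

std::vector<float> topk_allreduce_sketch(std::vector<Rank> &ranks,
                                         float threshold) {
  const size_t num = ranks[0].grad.size();
  const size_t nnzMax =
      std::min(num, static_cast<size_t>(std::ceil(threshold * num)));
  std::vector<float> reduced(num, 0.0f);  // plays the role of fusedRecvBuff

  for (Rank &r : ranks) {
    // add the previous accumulation, then keep a dense backup
    for (size_t i = 0; i < num; ++i) r.grad[i] += r.accumulation[i];
    std::vector<float> backup = r.grad;

    // indices of the nnzMax largest values, as the descending sort-by-key does
    std::vector<size_t> idx(num);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(
        idx.begin(), idx.begin() + nnzMax, idx.end(),
        [&](size_t a, size_t b) { return backup[a] > backup[b]; });

    // scatter-add the sent topK entries into the shared dense sum (what the
    // all-gather + cusparseSaxpyi loop achieves) and keep the rest as residual
    r.accumulation = backup;
    for (size_t k = 0; k < nnzMax; ++k) {
      reduced[idx[k]] += backup[idx[k]];
      r.accumulation[idx[k]] = 0.0f;  // sent, so nothing left to accumulate
    }
  }
  return reduced;  // every rank ends up with this same dense result
}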