void Communicator::valSparsAllReduce()

in src/io/communicator.cc [619:719]


void Communicator::valSparsAllReduce(size_t num, void *accumulation,
                                     Context *ctx) {
  CHECK_EQ(dataSize, sizeof(float))
      << "This function depends on thrust and currently supports only fp32";

  if (sparsInitialized == false) sparsInit();
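
  // Overall flow: (1) sparsify the dense gradients by threshold, keeping the
  // dropped elements as a residual in accumulation; (2) all-gather every
  // rank's non-zero count; (3) all-gather the packed {index, value} pairs;
  // (4) scatter-add each rank's sparse gradient into fusedRecvBuff.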

  if (accumulation != NULL) {
    // add back the residual accumulated from previous iterations
    cuda::add(num, static_cast<float *>(fusedSendBuff),
              static_cast<float *>(accumulation),
              static_cast<float *>(fusedSendBuff), ctx->c1);
    // back up fusedSendBuff so the new residual can be computed after
    // sparsification
    CUDA_CHECK(cudaMemcpyAsync(backupBuff, (const void *)fusedSendBuff,
                               sizeof(float) * num, cudaMemcpyDeviceToDevice,
                               ctx->c1));
  }

  // sparsification: zero out elements whose magnitude is below the threshold
  cuda::sparsabs(num, threshold, static_cast<float *>(fusedSendBuff),
                 static_cast<float *>(fusedSendBuff), ctx->c1);

  // write the new residual (the dropped elements) back to accumulation
  if (accumulation != NULL)
    cuda::sub(num, static_cast<float *>(backupBuff),
              static_cast<float *>(fusedSendBuff),
              static_cast<float *>(accumulation), ctx->c1);

  // produce the indices of the non-zero elements
  cuda::sparsindex(num, static_cast<float *>(fusedSendBuff), fusedIndex,
                   ctx->c1);

  // remove zeros from the index array to compact it, and get the number of
  // non-zeros (nnz)
  cuda::removezeroidx(num, fusedIndex, ctx->c1, nnz);

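  // copy the host-side nnz count to the device so NCCL can all-gather it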
  CUDA_CHECK(cudaMemcpyAsync((void *)nnzGPU, (const void *)nnz, sizeof(int),
                             cudaMemcpyHostToDevice, ctx->c1));

  // all-gather the nnz counts from all ranks
  NCCLCHECK(ncclAllGather((const void *)nnzGPU, (void *)nnzAllGPU, 1, ncclInt,
                          comm, ctx->c1));

  CUDA_CHECK(cudaMemcpyAsync((void *)nnzAll, (const void *)nnzAllGPU,
                             sizeof(int) * world_size, cudaMemcpyDeviceToHost,
                             ctx->c1));

  CUDA_CHECK(cudaStreamSynchronize(ctx->c1));

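  // ncclAllGather requires the same element count on every rank, so pad each
  // rank's sparse payload to the maximum nnz across all ranks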
  int nnzMax = 0;
  for (int i = 0; i < world_size; i++)
    if (nnzAll[i] > nnzMax) nnzMax = nnzAll[i];

  // remove zeros from the value array to compact it
  cuda::removezeroval(num, static_cast<float *>(fusedSendBuff), ctx->c1);

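  // pack [indices | values] contiguously into sparsSendBuff; the int indices
  // are stored bitwise in float slots (this relies on sizeof(int) ==
  // sizeof(float))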
  CUDA_CHECK(cudaMemcpyAsync(sparsSendBuff, (const void *)fusedIndex,
                             sizeof(int) * (*nnz), cudaMemcpyDeviceToDevice,
                             ctx->c1));
  CUDA_CHECK(
      cudaMemcpyAsync((void *)(static_cast<float *>(sparsSendBuff) + (*nnz)),
                      (const void *)fusedSendBuff, sizeof(float) * (*nnz),
                      cudaMemcpyDeviceToDevice, ctx->c1));

  // wait for the memcpy to complete
  CUDA_CHECK(cudaEventRecord(event, ctx->c1));
  CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));

  // all-gather the packed sparse gradients; each rank contributes
  // 2 * nnzMax elements (indices + values, padded)
  NCCLCHECK(ncclAllGather((const void *)sparsSendBuff, sparsRecvBuff,
                          2 * nnzMax, ncclFloat, comm, ctx->s));

  // wait for the all-gather to complete
  CUDA_CHECK(cudaEventRecord(event, ctx->s));
  CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));

  // reduce the sparse gradients: first zero the sum buffer
  CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), ctx->c2));

  size_t offset = 0;
  float alpha = 1.0;

  // add the sparse gradient from each rank into the sum buffer to complete
  // the all-reduce
  CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c2));

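  // each rank's segment in sparsRecvBuff spans 2 * nnzMax floats, laid out as
  // [ nnzAll[i] indices | nnzAll[i] values | unused padding ]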
  for (int i = 0; i < world_size; i++) {
    CUDA_CHECK(cudaMemcpyAsync(
        (void *)xInd,
        (const void *)(static_cast<float *>(sparsRecvBuff) + offset),
        sizeof(int) * nnzAll[i], cudaMemcpyDeviceToDevice, ctx->c2));
    offset += nnzAll[i];
    CUDA_CHECK(cudaMemcpyAsync(
        (void *)xVal,
        (const void *)(static_cast<float *>(sparsRecvBuff) + offset),
        sizeof(float) * nnzAll[i], cudaMemcpyDeviceToDevice, ctx->c2));
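    // skip the values just read plus the padding to reach the next rank's
    // segment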
    offset += (2 * nnzMax - nnzAll[i]);
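    // scatter-add: fusedRecvBuff[xInd[j] - 1] += alpha * xVal[j] (one-based)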
    CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzAll[i], &alpha, xVal,
                                  xInd, static_cast<float *>(fusedRecvBuff),
                                  CUSPARSE_INDEX_BASE_ONE));
  }
}
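
For reference, a host-side sketch of the unpack-and-reduce loop above may make
the offset arithmetic clearer. This is a hypothetical illustration, not SINGA
code: the name denseSumFromPacked is invented, recvBuff stands for a host copy
of sparsRecvBuff, and it assumes sizeof(int) == sizeof(float), as the packing
above does.

#include <cstring>
#include <vector>

// Hypothetical host-side equivalent of the cusparseSaxpyi loop: rank i's
// segment is the 2 * nnzMax floats starting at i * 2 * nnzMax, laid out as
// [ nnzAll[i] one-based indices (int bits) | nnzAll[i] values | padding ].
void denseSumFromPacked(const std::vector<float> &recvBuff,
                        const std::vector<int> &nnzAll, int nnzMax,
                        int worldSize, std::vector<float> &dense) {
  for (int i = 0; i < worldSize; i++) {
    const float *seg = recvBuff.data() + (size_t)i * 2 * nnzMax;
    for (int j = 0; j < nnzAll[i]; j++) {
      int idx;
      std::memcpy(&idx, &seg[j], sizeof(int));  // recover the int index bits
      dense[idx - 1] += seg[nnzAll[i] + j];     // one-based scatter-add
    }
  }
}

The thresholding helper cuda::sparsabs is defined elsewhere; assuming it simply
zeroes every element whose magnitude falls below the threshold (consistent with
how the residual is computed above), a minimal CUDA kernel doing the same could
look like:

// A minimal sketch of threshold sparsification (an assumption about what
// cuda::sparsabs computes, not its actual implementation).
__global__ void sparsAbsSketch(size_t num, float threshold, const float *in,
                               float *out) {
  size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  if (i < num) out[i] = (fabsf(in[i]) >= threshold) ? in[i] : 0.0f;
}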