in src/io/communicator.cc [721:807]
void Communicator::topKSparsAllReduce(size_t num, void *accumulation,
                                      Context *ctx) {
  CHECK_EQ(dataSize, sizeof(float))
      << "This function depends on thrust and supports only fp32 currently";
  if (sparsInitialized == false) sparsInit();

  // use gradient accumulation
  if (accumulation != NULL) {
    // add the previous accumulation
    cuda::add(num, static_cast<float *>(fusedSendBuff),
              static_cast<float *>(accumulation),
              static_cast<float *>(fusedSendBuff), ctx->c1);
    // backup the fusedSendBuff
    CUDA_CHECK(cudaMemcpyAsync(backupBuff, (const void *)fusedSendBuff,
                               sizeof(float) * num, cudaMemcpyDeviceToDevice,
                               ctx->c1));
  }
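  // when accumulation is used, backupBuff keeps the dense, still unsorted
  // gradient so that the unsent residual can be recovered after the sort below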
  // generate an index and sort the fusedSendBuff from large to small values
  cuda::generateindex(num, fusedIndex, ctx->c1);
  cuda::sortbykey(num, static_cast<float *>(fusedSendBuff), fusedIndex,
                  ctx->c1);

  // determine the number of topK for communication
  int nnzMax = (int)ceil(threshold * num);
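  // every rank computes the same nnzMax (it depends only on threshold and
  // num), which the fixed-count all-gather below relies on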
  // update the gradient accumulation
  float alpha = 1.0;
  if (accumulation != NULL) {
    CUDA_CHECK(cudaMemsetAsync(accumulation, 0, num * sizeof(float), ctx->c1));
    CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c1));
    CUSPARSE_CHECK(cusparseSaxpyi(
        cusparse_handle, nnzMax, &alpha, static_cast<float *>(fusedSendBuff),
        fusedIndex, static_cast<float *>(accumulation),
        CUSPARSE_INDEX_BASE_ONE));
    cuda::sub(num, static_cast<float *>(backupBuff),
              static_cast<float *>(accumulation),
              static_cast<float *>(accumulation), ctx->c1);
  }
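  // accumulation now holds the residual, i.e. backupBuff minus the scattered
  // topK values: the gradient mass not sent this round, added back next call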
  // pack the topK indices and values to be sent
  CUDA_CHECK(cudaMemcpyAsync(sparsSendBuff, (const void *)fusedIndex,
                             sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice,
                             ctx->c1));
  CUDA_CHECK(
      cudaMemcpyAsync((void *)(static_cast<float *>(sparsSendBuff) + nnzMax),
                      (const void *)fusedSendBuff, sizeof(float) * nnzMax,
                      cudaMemcpyDeviceToDevice, ctx->c1));
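  // sparsSendBuff layout: nnzMax int indices followed by nnzMax float values;
  // transferring them as 2 * nnzMax ncclFloat elements works because
  // sizeof(int) == sizeof(float)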
  // wait for the memcpy to complete
  CUDA_CHECK(cudaEventRecord(event, ctx->c1));
  CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));

  // all-gather all the sparse gradients
  NCCLCHECK(ncclAllGather((const void *)sparsSendBuff, (void *)sparsRecvBuff,
                          2 * nnzMax, ncclFloat, comm, ctx->s));

  // wait for the all-gather to complete
  CUDA_CHECK(cudaEventRecord(event, ctx->s));
  CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));
  // reduce the sparse gradients: first zero the sum buffer
  CUDA_CHECK(cudaMemsetAsync(fusedRecvBuff, 0, num * sizeof(float), ctx->c2));
  size_t offset = 0;
  CUSPARSE_CHECK(cusparseSetStream(cusparse_handle, ctx->c2));

  // add the sparse gradient from each rank to the sum buffer to finish the
  // all-reduce process
  for (int i = 0; i < world_size; i++) {
    CUDA_CHECK(cudaMemcpyAsync(
        (void *)xInd,
        (const void *)(static_cast<float *>(sparsRecvBuff) + offset),
        sizeof(int) * nnzMax, cudaMemcpyDeviceToDevice, ctx->c2));
    offset += nnzMax;
    CUDA_CHECK(cudaMemcpyAsync(
        (void *)xVal,
        (const void *)(static_cast<float *>(sparsRecvBuff) + offset),
        sizeof(float) * nnzMax, cudaMemcpyDeviceToDevice, ctx->c2));
    offset += nnzMax;
    CUSPARSE_CHECK(cusparseSaxpyi(cusparse_handle, nnzMax, &alpha, xVal, xInd,
                                  static_cast<float *>(fusedRecvBuff),
                                  CUSPARSE_INDEX_BASE_ONE));
  }
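  // fusedRecvBuff now holds the dense element-wise sum of every rank's topK
  // gradients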
}
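
// A minimal sketch, not the SINGA implementation, of what the
// cuda::generateindex and cuda::sortbykey helpers used above are assumed to do
// with Thrust (the CHECK above states the Thrust dependency): fill a 1-based
// index array to match the CUSPARSE_INDEX_BASE_ONE convention, then sort the
// values from large to small while permuting the indices in lock step. The
// namespace, names, and the exact comparator are assumptions for illustration.
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>

namespace cuda_sketch {

void generateindex(size_t num, int *index, cudaStream_t stream) {
  thrust::device_ptr<int> idx(index);
  // fill index with 1, 2, ..., num (1-based, matching CUSPARSE_INDEX_BASE_ONE)
  thrust::sequence(thrust::cuda::par.on(stream), idx, idx + num, 1);
}

void sortbykey(size_t num, float *value, int *index, cudaStream_t stream) {
  thrust::device_ptr<float> val(value);
  thrust::device_ptr<int> idx(index);
  // sort values in descending order, carrying the index array along so that
  // the first nnzMax entries of both arrays form the topK (index, value) pairs
  thrust::sort_by_key(thrust::cuda::par.on(stream), val, val + num, idx,
                      thrust::greater<float>());
}

}  // namespace cuda_sketch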