void Communicator::fusedSynch(vector<Tensor> &t, bool send)

in src/io/communicator.cc [212:302]
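
fusedSynch fuses a group of tensors into a single NCCL AllReduce. With send == false it stages data: each tensor in t is copied into fusedSendBuff on the side stream c1, and sendBuffOffset advances by the tensor's element count. With send == true it makes the communication stream s wait for the staged copies, all-reduces the sendBuffOffset buffered elements, resets the offset, and copies each tensor's reduced slice out of fusedRecvBuff back into t on c1. The NCCL data type and element size (ncclHalf/__half or ncclFloat/float) are chosen from the data type of t[0].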


void Communicator::fusedSynch(vector<Tensor> &t, bool send) {
  CHECK_GT(t.size(), 0);

  generateBlocks(t);

  if (t[0].data_type() == kFloat16) {
    ncclType = ncclHalf;
    dataSize = sizeof(__half);
  } else {
    ncclType = ncclFloat;
    dataSize = sizeof(float);
  }

  if (!send) {
    // buffer the tensors
    device_->Exec(
        [this, t](Context *ctx) mutable {
          // record an event on ctx->stream and make c1 wait on it
          CUDA_CHECK(cudaEventRecord(event, ctx->stream));
          CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
        },
        prev_blocks_, prev_blocks_, "Waiting");

    device_->Exec(
        [this, t](Context *ctx) mutable {
          // copy each tensor into fusedSendBuff at the current offset
          for (size_t i = 0; i < t.size(); i++) {
            if (t[0].data_type() == kFloat16) {
              offsetPointer = (void *)(static_cast<__half *>(fusedSendBuff) +
                                       sendBuffOffset);
            } else {
              offsetPointer = (void *)(static_cast<float *>(fusedSendBuff) +
                                       sendBuffOffset);
            }
            CUDA_CHECK(cudaMemcpyAsync(
                (void *)offsetPointer,
                (const void *)t[i].block()->mutable_data(),
                t[i].Size() * dataSize, cudaMemcpyDeviceToDevice, ctx->c1));
            sendBuffOffset += t[i].Size();
          }
        },
        prev_blocks_, blocks_, "Dist_c1_fusedSynch_filling");

  } else {
    // all-reduce the tensors already staged in the buffer
    device_->Exec(
        [this](Context *ctx) mutable {
          // make stream s wait until the memcpy on c1 has completed
          CUDA_CHECK(cudaEventRecord(event, ctx->c1));
          CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
        },
        prev_blocks_, prev_blocks_, "Waiting");

    device_->Exec(
        [this](Context *ctx) mutable {
          allReduce((int)sendBuffOffset, fusedSendBuff, fusedRecvBuff, ncclType,
                    ctx);
          sendBuffOffset = 0;
        },
        prev_blocks_, blocks_, "Dist_s_fusedSynch_allreduce");

    device_->Exec(
        [this](Context *ctx) mutable {
          // make c1 wait until the allreduce on s has completed
          CUDA_CHECK(cudaEventRecord(event, ctx->s));
          CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
        },
        blocks_, blocks_, "Waiting");

    device_->Exec(
        [this, t](Context *ctx) mutable {
          // copy data back to tensors after allreduce
          size_t offset = 0;
          for (size_t i = 0; i < t.size(); i++) {
            if (t[0].data_type() == kFloat16) {
              offsetPointer =
                  (void *)(static_cast<__half *>(fusedRecvBuff) + offset);
            } else {
              offsetPointer =
                  (void *)(static_cast<float *>(fusedRecvBuff) + offset);
            }
            CUDA_CHECK(cudaMemcpyAsync((void *)t[i].block()->mutable_data(),
                                       (const void *)offsetPointer,
                                       t[i].Size() * dataSize,
                                       cudaMemcpyDeviceToDevice, ctx->c1));
            offset += t[i].Size();
          }
        },
        blocks_, blocks_, "Dist_c1_fusedSynch_copyBackToTensor");
  }
}
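
For orientation, the following is a minimal usage sketch, not code from the SINGA sources: the wrapper name FusedAllReduce is hypothetical, and it assumes comm has already been initialized (NCCL communicator, fused buffers, and streams set up elsewhere). The same tensor list is staged with send == false and then reduced with send == true, so the copy-back offsets line up with the staging order.

#include <vector>

using singa::Communicator;  // namespace assumed from the SINGA sources
using singa::Tensor;

// Hypothetical driver: stage a group of tensors, then run one fused
// all-reduce over the whole buffer and scatter the results back.
void FusedAllReduce(Communicator &comm, std::vector<Tensor> &grads) {
  // Phase 1 (send == false): append every tensor into fusedSendBuff on
  // side stream c1; sendBuffOffset advances by t[i].Size() per tensor.
  comm.fusedSynch(grads, false);

  // Phase 2 (send == true): make stream s wait for the staged copies,
  // all-reduce the buffered elements, and copy each tensor's reduced
  // slice back in the same order it was staged.
  comm.fusedSynch(grads, true);
}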