void Communicator::fusedSynchHalf()

in src/io/communicator.cc [330:414]

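Fuses a list of float32 tensors for a half-precision all-reduce. Buffering
calls (send == false) pack the tensors into a single device buffer on side
stream c1; the send call (send == true) converts the buffer to fp16,
all-reduces it on the NCCL stream s, converts the result back to fp32 on
stream c2, and copies the reduced values back into the tensors.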

void Communicator::fusedSynchHalf(vector<Tensor> &t, bool send) {
  CHECK_GT(t.size(), 0);

  CHECK_EQ(t[0].data_type(), kFloat32)
      << "This function only supports tensors of 32-bit float precision, "
         "which are converted to 16 bits before transmission";

  generateBlocks(t);

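  // initialise the half-precision buffers on first use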
  if (!halfInitialized) halfInit();

  if (!send) {
    // buffer the tensors; the float->half conversion happens on the send pass
    device_->Exec(
        [this](Context *ctx) mutable {
          // make side stream c1 wait for work already queued on the default
          // cuda stream
          CUDA_CHECK(cudaEventRecord(event, ctx->stream));
          CUDA_CHECK(cudaStreamWaitEvent(ctx->c1, event, 0));
        },
        prev_blocks_, prev_blocks_, "Waiting");
    device_->Exec(
        [this, t](Context *ctx) mutable {
          // pack each tensor into fusedSendBuff; sendBuffOffset persists
          // across calls, so successive buffering passes append until the
          // send pass resets it to zero
          for (size_t i = 0; i < t.size(); i++) {
            CUDA_CHECK(cudaMemcpyAsync(
                (void *)(static_cast<float *>(fusedSendBuff) + sendBuffOffset),
                (const void *)t[i].block()->mutable_data(),
                t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice,
                ctx->c1));
            sendBuffOffset += t[i].Size();
          }
        },
        prev_blocks_, blocks_, "Dist_c1_fusedSynchHalf_filling");
  } else {
    // convert the buffered data to half, all-reduce it, and copy the
    // results back into the tensors
    device_->Exec(
        [this](Context *ctx) mutable {
          cuda::float2half(sendBuffOffset, static_cast<float *>(fusedSendBuff),
                           static_cast<__half *>(fusedSendBuffHalf), ctx->c1);
        },
        prev_blocks_, blocks_, "Dist_c1_fusedSynchHalf_float2half");
    device_->Exec(
        [this](Context *ctx) mutable {
          // make the NCCL stream s wait until the packing and conversion
          // work on c1 has completed
          CUDA_CHECK(cudaEventRecord(event, ctx->c1));
          CUDA_CHECK(cudaStreamWaitEvent(ctx->s, event, 0));
        },
        blocks_, blocks_, "Waiting");
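    // the fused buffer is all-reduced in half precision on the NCCL stream s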
    device_->Exec(
        [this](Context *ctx) mutable {
          allReduce((int)sendBuffOffset, fusedSendBuffHalf, fusedRecvBuffHalf,
                    ncclHalf, ctx);
        },
        blocks_, blocks_, "Dist_s_fusedSynchHalf_allreduce");
    device_->Exec(
        [this](Context *ctx) mutable {
          // make stream c2 wait until the all-reduce on s has completed
          CUDA_CHECK(cudaEventRecord(event, ctx->s));
          CUDA_CHECK(cudaStreamWaitEvent(ctx->c2, event, 0));
        },
        blocks_, blocks_, "Waiting");
    device_->Exec(
        [this, t](Context *ctx) mutable {
          cuda::half2float(sendBuffOffset,
                           static_cast<__half *>(fusedRecvBuffHalf),
                           static_cast<float *>(fusedRecvBuff), ctx->c2);

          // the fused buffer has been consumed; reset the pack offset for
          // the next round of buffering
          sendBuffOffset = 0;

          // copy data back to tensors after allreduce
          size_t offset = 0;
          for (size_t i = 0; i < t.size(); i++) {
            CUDA_CHECK(cudaMemcpyAsync(
                (void *)t[i].block()->mutable_data(),
                (const void *)(static_cast<float *>(fusedRecvBuff) + offset),
                t[i].Size() * sizeof(float), cudaMemcpyDeviceToDevice,
                ctx->c2));
            offset += t[i].Size();
          }
        },
        blocks_, blocks_, "Dist_c2_fusedSynchHalf_half2floatcopy");
  }
}
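
A minimal calling sketch, assuming a constructed Communicator comm and a
vector grads of float32 gradient tensors on the communicator's device (the
setup names below are hypothetical, not taken from the SINGA sources). The
same tensor list is buffered first and then sent, so the copy-back offsets
line up with the packed layout:

// Hypothetical usage; comm, grads and collectGradients() are assumed
// to exist and are not part of communicator.cc.
std::vector<singa::Tensor> grads = collectGradients();  // float32, on GPU

comm.fusedSynchHalf(grads, false);  // pass 1: pack tensors into fusedSendBuff
comm.fusedSynchHalf(grads, true);   // pass 2: float->half, all-reduce,
                                    // half->float, copy back into grads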