// static void clientPingPongNonBlock()
//
// From tensorpipe/benchmark/benchmark_pipe.cc [422:636]

// Client side of the ping-pong benchmark: sends one message to the server and
// reads the echoed copy back, then re-invokes itself (from inside the read
// completion callback) until all warm-up and measured round trips are done.
// When the last round trip completes it prints the measurements and fulfills
// `doneProm`.
//
// Parameters:
//  - pipe:          the pipe connected to the server; shared ownership is
//                   propagated into every callback capture.
//  - numWarmUps:    remaining un-measured round trips; decremented in place
//                   across recursive invocations.
//  - numRoundTrips: remaining measured round trips; decremented in place once
//                   warm-ups are exhausted.
//  - doneProm:      fulfilled exactly once, after the final round trip.
//  - data:          buffers, sizes, metadata and CUDA/NCCL handles to use.
//  - measurements:  CPU timer marked per round trip; CUDA timer marked once
//                   every `data.cudaSyncPeriod` round trips (after a stream
//                   sync), amortizing the synchronization cost.
//
// NOTE(review): numWarmUps/numRoundTrips/doneProm/data/measurements are
// captured by reference in the async callbacks, so the caller must keep them
// alive until doneProm is set — presumably it blocks on the promise's future;
// confirm against the caller.
static void clientPingPongNonBlock(
    std::shared_ptr<Pipe> pipe,
    int& numWarmUps,
    int& numRoundTrips,
    std::promise<void>& doneProm,
    Data& data,
    MultiDeviceMeasurements& measurements) {
#if USE_NCCL
  // NCCL baseline path: run the whole benchmark as a synchronous loop of
  // paired send/recv on the same stream, then return early so the TensorPipe
  // path below is never reached.
  for (int iterIdx = 0; iterIdx < numWarmUps + numRoundTrips; iterIdx++) {
    // Only measure once the warm-up iterations are done.
    if (iterIdx >= numWarmUps) {
      measurements.cpu.markStart();
      // Start the CUDA timer only at the beginning of each sync period.
      if ((iterIdx - numWarmUps) % data.cudaSyncPeriod == 0) {
        measurements.cuda.markStart();
      }
    }
    // NOTE(review): peer rank is hard-coded to 0 — assumes the server is
    // rank 0 in the NCCL communicator; confirm against server setup.
    TP_NCCL_CHECK(ncclSend(
        data.expectedCudaTensor[0].get(),
        data.tensorSize,
        ncclInt8,
        0,
        data.ncclComm.get(),
        data.cudaStream.get()));
    TP_NCCL_CHECK(ncclRecv(
        data.temporaryCudaTensor[0].get(),
        data.tensorSize,
        ncclInt8,
        0,
        data.ncclComm.get(),
        data.cudaStream.get()));
    if (iterIdx >= numWarmUps) {
      // CPU timer measures submission latency only (no stream sync here).
      measurements.cpu.markStop();
      // Sync the stream and stop the CUDA timer once per sync period; the
      // single interval is attributed to cudaSyncPeriod round trips.
      if ((iterIdx - numWarmUps + 1) % data.cudaSyncPeriod == 0) {
        TP_CUDA_CHECK(cudaStreamSynchronize(data.cudaStream.get()));
        measurements.cuda.markStop(data.cudaSyncPeriod);
      }
    }
  }
  printMultiDeviceMeasurements(measurements, data.payloadSize);
  doneProm.set_value();
  return;
#endif // USE_NCCL
  // TensorPipe path: one invocation == one round trip, chained recursively.
  // Start timers only once warm-ups have been used up.
  if (numWarmUps == 0) {
    measurements.cpu.markStart();
    // numRoundTrips counts down, so this fires once per cudaSyncPeriod
    // measured round trips (pairing with the markStop in the read callback).
    if (numRoundTrips % data.cudaSyncPeriod == 0) {
      measurements.cuda.markStart();
    }
  }
  // Build the outgoing message from the pre-filled "expected" buffers; the
  // pipe only borrows these pointers, ownership stays with `data`.
  Message message;
  message.metadata = data.expectedMetadata;
  if (data.payloadSize > 0) {
    for (size_t payloadIdx = 0; payloadIdx < data.numPayloads; payloadIdx++) {
      Message::Payload payload;
      payload.data = data.expectedPayload[payloadIdx].get();
      payload.length = data.payloadSize;
      message.payloads.push_back(std::move(payload));
    }
  } else {
    TP_DCHECK_EQ(message.payloads.size(), 0);
  }
  if (data.tensorSize > 0) {
    for (size_t tensorIdx = 0; tensorIdx < data.numTensors; tensorIdx++) {
      Message::Tensor tensor;
      tensor.length = data.tensorSize;
      if (data.tensorType == TensorType::kCpu) {
        tensor.buffer =
            CpuBuffer{.ptr = data.expectedCpuTensor[tensorIdx].get()};
        tensor.targetDevice = Device(kCpuDeviceType, 0);
      } else if (data.tensorType == TensorType::kCuda) {
        tensor.buffer = CudaBuffer{
            .ptr = data.expectedCudaTensor[tensorIdx].get(),
            .stream = data.cudaStream.get(),
        };
        tensor.targetDevice = Device(kCudaDeviceType, 0);
      } else {
        TP_THROW_ASSERT() << "Unknown tensor type";
      }
      message.tensors.push_back(std::move(tensor));
    }
  } else {
    TP_DCHECK_EQ(message.tensors.size(), 0);
  }
  // Async chain: write -> readDescriptor -> read -> (recurse or finish).
  pipe->write(
      std::move(message),
      [pipe, &numWarmUps, &numRoundTrips, &doneProm, &data, &measurements](
          const Error& error) {
        TP_THROW_ASSERT_IF(error) << error.what();
        // Wait for the server's echo: first its descriptor, so we can point
        // the allocation at our pre-allocated "temporary" buffers.
        pipe->readDescriptor([pipe,
                              &numWarmUps,
                              &numRoundTrips,
                              &doneProm,
                              &data,
                              &measurements](
                                 const Error& error, Descriptor descriptor) {
          TP_THROW_ASSERT_IF(error) << error.what();

          // Validate the echoed descriptor against what we sent, and target
          // each incoming payload/tensor at the matching temporary buffer.
          Allocation allocation;
          TP_DCHECK_EQ(descriptor.metadata, data.expectedMetadata);
          if (data.payloadSize > 0) {
            TP_DCHECK_EQ(descriptor.payloads.size(), data.numPayloads);
            allocation.payloads.resize(data.numPayloads);
            for (size_t payloadIdx = 0; payloadIdx < data.numPayloads;
                 payloadIdx++) {
              TP_DCHECK_EQ(
                  descriptor.payloads[payloadIdx].metadata,
                  data.expectedPayloadMetadata[payloadIdx]);
              TP_DCHECK_EQ(
                  descriptor.payloads[payloadIdx].length, data.payloadSize);
              allocation.payloads[payloadIdx].data =
                  data.temporaryPayload[payloadIdx].get();
            }
          } else {
            TP_DCHECK_EQ(descriptor.payloads.size(), 0);
          }
          if (data.tensorSize > 0) {
            TP_DCHECK_EQ(descriptor.tensors.size(), data.numTensors);
            allocation.tensors.resize(data.numTensors);
            for (size_t tensorIdx = 0; tensorIdx < data.numTensors;
                 tensorIdx++) {
              TP_DCHECK_EQ(
                  descriptor.tensors[tensorIdx].metadata,
                  data.expectedTensorMetadata[tensorIdx]);
              TP_DCHECK_EQ(
                  descriptor.tensors[tensorIdx].length, data.tensorSize);
              if (data.tensorType == TensorType::kCpu) {
                allocation.tensors[tensorIdx].buffer = CpuBuffer{
                    .ptr = data.temporaryCpuTensor[tensorIdx].get(),
                };
              } else if (data.tensorType == TensorType::kCuda) {
                allocation.tensors[tensorIdx].buffer = CudaBuffer{
                    .ptr = data.temporaryCudaTensor[tensorIdx].get(),
                    .stream = data.cudaStream.get(),
                };
              } else {
                TP_THROW_ASSERT() << "Unknown tensor type";
              }
            }
          } else {
            TP_DCHECK_EQ(descriptor.tensors.size(), 0);
          }
          // Receive the echoed data into the allocation. The descriptor is
          // moved into the capture (still needed below for length checks)
          // and the allocation is captured by copy.
          pipe->read(
              allocation,
              [pipe,
               &numWarmUps,
               &numRoundTrips,
               &doneProm,
               &data,
               &measurements,
               descriptor{std::move(descriptor)},
               allocation](const Error& error) {
                // Stop the timers before the error check so the measurement
                // boundary sits as close to completion as possible.
                if (numWarmUps == 0) {
                  measurements.cpu.markStop();
                  // Every cudaSyncPeriod-th measured round trip: sync the
                  // stream and record one interval covering the whole period.
                  if ((numRoundTrips - 1) % data.cudaSyncPeriod == 0) {
                    TP_CUDA_CHECK(cudaStreamSynchronize(data.cudaStream.get()));
                    measurements.cuda.markStop(data.cudaSyncPeriod);
                  }
                }
                TP_THROW_ASSERT_IF(error) << error.what();
                // Verify the echoed bytes match what we sent (CPU only).
                if (data.payloadSize > 0) {
                  TP_DCHECK_EQ(allocation.payloads.size(), data.numPayloads);
                  for (size_t payloadIdx = 0; payloadIdx < data.numPayloads;
                       payloadIdx++) {
                    TP_DCHECK_EQ(
                        memcmp(
                            allocation.payloads[payloadIdx].data,
                            data.expectedPayload[payloadIdx].get(),
                            descriptor.payloads[payloadIdx].length),
                        0);
                  }
                } else {
                  TP_DCHECK_EQ(allocation.payloads.size(), 0);
                }
                if (data.tensorSize > 0) {
                  TP_DCHECK_EQ(allocation.tensors.size(), data.numTensors);
                  for (size_t tensorIdx = 0; tensorIdx < data.numTensors;
                       tensorIdx++) {
                    if (data.tensorType == TensorType::kCpu) {
                      TP_DCHECK_EQ(
                          memcmp(
                              allocation.tensors[tensorIdx]
                                  .buffer.unwrap<CpuBuffer>()
                                  .ptr,
                              data.expectedCpuTensor[tensorIdx].get(),
                              descriptor.tensors[tensorIdx].length),
                          0);
                    } else if (data.tensorType == TensorType::kCuda) {
                      // No (easy) way to do a memcmp with CUDA, I
                      // believe...
                    } else {
                      TP_THROW_ASSERT() << "Unknown tensor type";
                    }
                  }
                } else {
                  TP_DCHECK_EQ(allocation.tensors.size(), 0);
                }
                // Consume warm-ups first, then measured round trips.
                if (numWarmUps > 0) {
                  numWarmUps -= 1;
                } else {
                  numRoundTrips -= 1;
                }
                if (numRoundTrips > 0) {
                  // Schedule the next round trip (tail "recursion" via the
                  // callback; the stack does not grow across round trips).
                  clientPingPongNonBlock(
                      pipe,
                      numWarmUps,
                      numRoundTrips,
                      doneProm,
                      data,
                      measurements);
                } else {
                  // All done: report and unblock the caller.
                  printMultiDeviceMeasurements(measurements, data.payloadSize);
                  doneProm.set_value();
                }
              });
        });
      });
}