static void serverPongPingNonBlock()

in tensorpipe/benchmark/benchmark_pipe.cc [173:342]


// Server side of one non-blocking "pong-ping" exchange: receive a message
// from the client, validate it against the reference data, echo a reply,
// and re-arm itself until all warm-up and measured round trips are done,
// at which point `doneProm` is fulfilled.
//
// The counters, promise, `data` and `measurements` are owned by the caller
// and captured by reference into the pipe callbacks, so they must outlive
// the whole callback chain (the caller is expected to block on the promise).
// NOTE(review): `measurements` is only forwarded into the recursive call
// here — the server side does not record timings itself.
static void serverPongPingNonBlock(
    std::shared_ptr<Pipe> pipe,
    int& numWarmUps,
    int& numRoundTrips,
    std::promise<void>& doneProm,
    Data& data,
    Measurements& measurements) {
#if USE_NCCL
  // NCCL baseline path: run every iteration (warm-ups + measured ones)
  // synchronously on the CUDA stream — recv from peer rank 1, then send the
  // same buffer back — then signal completion and skip the TensorPipe path
  // below entirely.
  for (int iterIdx = 0; iterIdx < numWarmUps + numRoundTrips; iterIdx++) {
    // TODO Handle multiple tensors.
    TP_NCCL_CHECK(ncclRecv(
        data.temporaryCudaTensor[0].get(),
        data.tensorSize,
        ncclInt8,
        1,
        data.ncclComm.get(),
        data.cudaStream.get()));
    TP_NCCL_CHECK(ncclSend(
        data.temporaryCudaTensor[0].get(),
        data.tensorSize,
        ncclInt8,
        1,
        data.ncclComm.get(),
        data.cudaStream.get()));
  }
  doneProm.set_value();
  return;
#endif // USE_NCCL
  // Step 1: wait for the incoming message's descriptor, check that its
  // metadata/sizes match what the benchmark expects, and build an Allocation
  // that points each payload/tensor at the preallocated temporary buffers
  // (so the read below copies into reusable storage, no per-iteration
  // allocation).
  pipe->readDescriptor(
      [pipe, &numWarmUps, &numRoundTrips, &doneProm, &data, &measurements](
          const Error& error, Descriptor descriptor) {
        TP_THROW_ASSERT_IF(error) << error.what();
        Allocation allocation;
        TP_DCHECK_EQ(descriptor.metadata, data.expectedMetadata);
        if (data.payloadSize > 0) {
          TP_DCHECK_EQ(descriptor.payloads.size(), data.numPayloads);
          allocation.payloads.resize(data.numPayloads);
          for (size_t payloadIdx = 0; payloadIdx < data.numPayloads;
               payloadIdx++) {
            TP_DCHECK_EQ(
                descriptor.payloads[payloadIdx].metadata,
                data.expectedPayloadMetadata[payloadIdx]);
            TP_DCHECK_EQ(
                descriptor.payloads[payloadIdx].length, data.payloadSize);
            allocation.payloads[payloadIdx].data =
                data.temporaryPayload[payloadIdx].get();
          }
        } else {
          TP_DCHECK_EQ(descriptor.payloads.size(), 0);
        }
        if (data.tensorSize > 0) {
          TP_DCHECK_EQ(descriptor.tensors.size(), data.numTensors);
          allocation.tensors.resize(data.numTensors);
          for (size_t tensorIdx = 0; tensorIdx < data.numTensors; tensorIdx++) {
            TP_DCHECK_EQ(
                descriptor.tensors[tensorIdx].metadata,
                data.expectedTensorMetadata[tensorIdx]);
            TP_DCHECK_EQ(descriptor.tensors[tensorIdx].length, data.tensorSize);
            if (data.tensorType == TensorType::kCpu) {
              allocation.tensors[tensorIdx].buffer = CpuBuffer{
                  .ptr = data.temporaryCpuTensor[tensorIdx].get(),
              };
            } else if (data.tensorType == TensorType::kCuda) {
              allocation.tensors[tensorIdx].buffer = CudaBuffer{
                  .ptr = data.temporaryCudaTensor[tensorIdx].get(),
                  .stream = data.cudaStream.get(),
              };
            } else {
              TP_THROW_ASSERT() << "Unknown tensor type";
            }
          }
        } else {
          TP_DCHECK_EQ(descriptor.tensors.size(), 0);
        }

        // Step 2: read the message body into the buffers set up above.
        // `descriptor` is moved into the closure (still needed for lengths /
        // source devices); `allocation` is copied, which is cheap since it
        // only holds raw pointers into `data`'s buffers.
        pipe->read(
            allocation,
            [pipe,
             &numWarmUps,
             &numRoundTrips,
             &doneProm,
             &data,
             &measurements,
             descriptor{std::move(descriptor)},
             allocation](const Error& error) {
              TP_THROW_ASSERT_IF(error) << error.what();

              // Build the reply ("pong"): payloads are sent from the
              // expected reference buffers, tensors are echoed straight from
              // the buffers we just received into, targeting the device the
              // client sent them from.
              Message message;
              if (data.payloadSize > 0) {
                TP_DCHECK_EQ(allocation.payloads.size(), data.numPayloads);
                message.payloads.resize(data.numPayloads);
                for (size_t payloadIdx = 0; payloadIdx < data.numPayloads;
                     payloadIdx++) {
                  TP_DCHECK_EQ(
                      descriptor.payloads[payloadIdx].length, data.payloadSize);
                  // Debug-only content check: received payload must match the
                  // expected reference bytes.
                  TP_DCHECK_EQ(
                      memcmp(
                          allocation.payloads[payloadIdx].data,
                          data.expectedPayload[payloadIdx].get(),
                          descriptor.payloads[payloadIdx].length),
                      0);
                  message.payloads[payloadIdx] = {
                      .data = data.expectedPayload[payloadIdx].get(),
                      .length = descriptor.payloads[payloadIdx].length,
                  };
                }
              } else {
                TP_DCHECK_EQ(allocation.payloads.size(), 0);
              }
              if (data.tensorSize > 0) {
                TP_DCHECK_EQ(allocation.tensors.size(), data.numTensors);
                message.tensors.resize(data.numTensors);
                for (size_t tensorIdx = 0; tensorIdx < data.numTensors;
                     tensorIdx++) {
                  TP_DCHECK_EQ(
                      descriptor.tensors[tensorIdx].length, data.tensorSize);
                  if (data.tensorType == TensorType::kCpu) {
                    TP_DCHECK_EQ(
                        memcmp(
                            allocation.tensors[tensorIdx]
                                .buffer.unwrap<CpuBuffer>()
                                .ptr,
                            data.expectedCpuTensor[tensorIdx].get(),
                            descriptor.tensors[tensorIdx].length),
                        0);
                  } else if (data.tensorType == TensorType::kCuda) {
                    // No (easy) way to do a memcmp with CUDA, I believe...
                  } else {
                    TP_THROW_ASSERT() << "Unknown tensor type";
                  }
                  message.tensors[tensorIdx] = {
                      .buffer = allocation.tensors[tensorIdx].buffer,
                      .length = descriptor.tensors[tensorIdx].length,
                      .targetDevice =
                          descriptor.tensors[tensorIdx].sourceDevice,
                  };
                }
              } else {
                TP_DCHECK_EQ(allocation.tensors.size(), 0);
              }

              // Step 3: send the reply. Once it completes, consume one
              // warm-up iteration first (warm-ups don't count as round
              // trips); only afterwards start decrementing the measured
              // round-trip counter. Re-arm for the next exchange, or fulfill
              // the promise when all round trips are done.
              pipe->write(
                  std::move(message),
                  [pipe,
                   &numWarmUps,
                   &numRoundTrips,
                   &doneProm,
                   &data,
                   &measurements](const Error& error) {
                    TP_THROW_ASSERT_IF(error) << error.what();
                    if (numWarmUps > 0) {
                      numWarmUps -= 1;
                    } else {
                      numRoundTrips -= 1;
                    }
                    if (numRoundTrips > 0) {
                      serverPongPingNonBlock(
                          pipe,
                          numWarmUps,
                          numRoundTrips,
                          doneProm,
                          data,
                          measurements);
                    } else {
                      doneProm.set_value();
                    }
                  });
            });
      });
}