void MPIRendezvousMgr::AddRequest()

in tensorflow_networking/mpi/mpi_rendezvous_mgr.cc [150:269]


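// Handles an incoming RecvTensorRequest: the requested tensor is produced
// locally via RecvLocalAsync and then streamed back to MPI rank `mpi_dst`
// using non-blocking sends.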
void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
                                  const int mpi_dst) {
  TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique(
      request.request_id(), "RecvTensor (MPIRendezvousMgr)", request));
  const int64 step_id = request.step_id();
  const std::string& key = request.rendezvous_key();
  Rendezvous::ParsedKey parsed;
  TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));

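  // Send callback: serializes the response metadata into an MPI-allocated
  // buffer and posts non-blocking sends for it and, when the payload travels
  // separately, for the raw tensor bytes.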
  MPIRecvTensorCallBack send_cb = [this, mpi_dst, parsed](
      const Status& status, const Rendezvous::Args& send_args,
      const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead,
      MPISendTensorCall* mpi_send_call) {
    // TODO(jbedorf) this should be a loop over max size
    CHECK(mpi_send_call->mRes_.ByteSize() < INT_MAX)
        << "Buffer too large for single transfer";
    MPI_CHECK(MPI_Alloc_mem(mpi_send_call->mRes_.ByteSize(), MPI_INFO_NULL,
                            &mpi_send_call->send_buffer_));
    mpi_send_call->mRes_.SerializeToArray(mpi_send_call->send_buffer_,
                                          mpi_send_call->mRes_.ByteSize());

    MPI_CHECK(MPI_Isend(mpi_send_call->send_buffer_,
                        static_cast<int>(mpi_send_call->mRes_.ByteSize()),
                        MPI_CHAR, mpi_dst, TAG_SENDTENSOR, MPI_COMM_WORLD,
                        &(mpi_send_call->msg1_)));
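    // Test once for immediate completion of the metadata send; the result is
    // recorded in done1_.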
    MPI_CHECK(MPI_Test(&mpi_send_call->msg1_, &mpi_send_call->done1_,
                       MPI_STATUS_IGNORE));

    if (!mpi_send_call->mRes_.singlesend()) {
      const int tensor_size = static_cast<int>(val.TotalBytes());
      void* temp = const_cast<void*>(DMAHelper::base(&val));

      // If the MPI library is not GPU-aware, a device-to-host copy would be
      // needed here to stage the data on the host.
      // if(src_dev->tensorflow_gpu_device_info()) //memcpy to send_buffer2_

      // TODO(jbedorf) this should be a loop over max size
      MPI_CHECK(MPI_Isend(temp, tensor_size, MPI_CHAR, mpi_dst, TAG_SENDTENSOR2,
                          MPI_COMM_WORLD, &mpi_send_call->msg2_));
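      // The payload is sent directly from the tensor's buffer; done2_ marks
      // this second transfer as still in flight.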
      mpi_send_call->done2_ = 0;
    }
    return mpi_send_call;
  };

  // Wrapper around the read callback that places the send callback on our
  // queue
  Rendezvous::DoneCallback done_cb = [this, parsed, step_id, send_cb](
      const Status& status, const Rendezvous::Args& send_args,
      const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) {
    // RecvLocalAsync must succeed; a failure here aborts the process.
    CHECK(status.ok()) << "RecvLocalAsync was not ok, key: "
                       << parsed.FullKey() << " step: " << step_id
                       << " error message: " << status.error_message();

    VLOG(3) << "MPI Sending tensor " << parsed.FullKey()
            << " @ step: " << step_id << std::endl;

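    // The call object owns the response proto, the MPI requests, and the
    // send buffer for this transfer.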
    auto mpi_send_call = new MPISendTensorCall();
    mpi_send_call->Init(parsed, step_id, is_dead);

    Device* src_dev = nullptr;
    Status s = this->worker_env_2->device_mgr->LookupDevice(parsed.src_device,
                                                            &src_dev);
    CHECK(s.ok()) << "src device not found";

    // Decide whether shape and data should be sent together, or whether we
    // can optimize by using two separate transfers, thereby reducing memory
    // copies
    bool doOptimalTransfer = true;
    if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false;
    if (val.TotalBytes() < 1024) doOptimalTransfer = false;

    doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_;

    if (doOptimalTransfer) {
      // First send the Tensor description; the data follows in a separate
      // transfer
      mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype(
          val.dtype());
      val.shape().AsProto(mpi_send_call->mRes_.mutable_response()
                              ->mutable_tensor()
                              ->mutable_tensor_shape());
      mpi_send_call->mRes_.set_singlesend(false);
    } else {
      // Send the Tensor description and data in a single transfer
      if (src_dev->tensorflow_gpu_device_info() &&
          (!send_args.alloc_attrs.on_host())) {
        Notification n;
        GPUUtil::SetProtoFromGPU(
            val, src_dev, send_args.device_context,
            mpi_send_call->mRes_.mutable_response()->mutable_tensor(), is_dead,
            [&n, &s](const Status& s_) {
              s = s_;
              n.Notify();
            });
        n.WaitForNotification();
      } else {
        val.AsProtoTensorContent(
            mpi_send_call->mRes_.mutable_response()->mutable_tensor());
      }
    }

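    // Bind all arguments now so the send can run later as a zero-argument
    // callable on whichever thread drains the send queue.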
    std::function<MPISendTensorCall*()> res = std::bind(
        send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);

    SendQueueEntry req(string(parsed.FullKey()), std::move(res));

    this->QueueSendRequest(req);

    // Wait for the notification that indicates the tensor has been
    // successfully transmitted to the remote process. Only needed when the
    // data was not serialized into the proto, since the transfer then reads
    // directly from the tensor's buffer.
    if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification();
  };  // done_cb

  worker_env_2->compute_pool->Schedule([this, step_id, parsed, done_cb]() {
    this->RecvLocalAsync(step_id, parsed, done_cb);
  });
}
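
For context, here is a minimal sketch of what the matching receive side of this two-message protocol could look like, assuming the same tag pair: TAG_SENDTENSOR carries the serialized response proto, and TAG_SENDTENSOR2 carries the raw tensor bytes when singlesend() is false. ReceiveTensorSketch, the tag values, and ResponseProto (a stand-in for the generated response message) are illustrative assumptions, not the manager's actual receive path.

#include <mpi.h>

#include <vector>

// Assumed tag values; the real constants live in the MPI rendezvous headers.
enum { TAG_SENDTENSOR = 1, TAG_SENDTENSOR2 = 2 };

// ResponseProto is a placeholder for the generated response message, assumed
// to provide ParseFromArray() and singlesend().
template <typename ResponseProto>
void ReceiveTensorSketch(int mpi_src, ResponseProto* response,
                         std::vector<char>* tensor_bytes) {
  // Probe first so the buffer for the serialized proto can be sized exactly.
  MPI_Status status;
  MPI_Probe(mpi_src, TAG_SENDTENSOR, MPI_COMM_WORLD, &status);
  int proto_size = 0;
  MPI_Get_count(&status, MPI_CHAR, &proto_size);

  std::vector<char> proto_buf(proto_size);
  MPI_Recv(proto_buf.data(), proto_size, MPI_CHAR, mpi_src, TAG_SENDTENSOR,
           MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  response->ParseFromArray(proto_buf.data(), proto_size);

  // The raw tensor bytes arrive in a second message only when the sender
  // chose the two-transfer path (singlesend() == false).
  if (!response->singlesend()) {
    MPI_Probe(mpi_src, TAG_SENDTENSOR2, MPI_COMM_WORLD, &status);
    int payload_size = 0;
    MPI_Get_count(&status, MPI_CHAR, &payload_size);
    tensor_bytes->resize(payload_size);
    MPI_Recv(tensor_bytes->data(), payload_size, MPI_CHAR, mpi_src,
             TAG_SENDTENSOR2, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  }
}

A production receive path would likely replace the blocking MPI_Probe/MPI_Recv pairs with non-blocking variants so that a single thread can multiplex many outstanding transfers.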