void MPIRemoteRendezvous::RecvFromRemoteAsync()

in tensorflow_networking/mpi/mpi_rendezvous_mgr.cc [61:140]

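Asynchronously requests a tensor from a remote MPI process: the source task is mapped to an MPI rank, a serialized tensor request is queued as a non-blocking MPI_Isend to that rank, and a callback is registered that materializes the tensor when the response arrives, either inline in the response proto (single-send) or through a follow-up raw-byte MPI_Recv.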

void MPIRemoteRendezvous::RecvFromRemoteAsync(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
    DoneCallback done) {
  Status s = Status::OK();
  MPIRequestTensorCall* rendezvous_call = new MPIRequestTensorCall();

  VLOG(2) << "MPI User requested " << parsed.FullKey()
          << " @ step: " << step_id_;

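  // Map the source "job:replica:task" string onto its MPI rank.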
  std::string src_task = strings::StrCat(
      parsed.src.job, ":", parsed.src.replica, ":", parsed.src.task);
  const int dst = mpiutils_->GetSourceID(src_task);

  Device* dst_device;
  s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s, Args(), recv_args, Tensor{}, false);
    return;
  }

  // Set properties of the request object and create the request function
  rendezvous_call->Init(parsed, step_id_);

  std::function<void()> request_call = [parsed, dst, rendezvous_call]() {
    // Use MPI_Alloc_mem here to force allocation inside the MPI thread;
    // this is not optimal, but it prevents memory corruption and
    // segmentation faults during inter-server transfers...
    MPI_CHECK(MPI_Alloc_mem(rendezvous_call->request_buffer_size_,
                            MPI_INFO_NULL, &rendezvous_call->request_buffer_));
    rendezvous_call->req_.SerializeToArray(
        rendezvous_call->request_buffer_,
        rendezvous_call->request_buffer_size_);
    MPI_CHECK(MPI_Isend(rendezvous_call->request_buffer_,
                        rendezvous_call->request_buffer_size_, MPI_CHAR, dst,
                        TAG_REQTENSOR, MPI_COMM_WORLD,
                        &rendezvous_call->mpi_request_));
  };
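  // NOTE: request_call only starts the MPI_Isend; whichever thread later
  // drains the request queue is expected to complete mpi_request_
  // (MPI_Wait/MPI_Test) and release request_buffer_ with MPI_Free_mem.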

  // Create the function that is called when the tensor has been sent by the
  // remote side. step_id_ is captured by value so the callback can log it.
  const int64 temp1 = step_id_;
  rendezvous_call->recv_call_ =
      [this, parsed, recv_args, done, dst, temp1,
       rendezvous_call](MPIRecvTensorResponse mpi_response) {
        Device* dst_device;
        Status s =
            env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
        CHECK(s.ok()) << "Device lookup failed";

        VLOG(3) << "MPI Received tensor " << parsed.FullKey()
                << " @ step: " << temp1
                << " single-send: " << mpi_response.singlesend();

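        // A single-send response carries the tensor inline in the proto.
        // Otherwise this message is only metadata: allocate the tensor from
        // it, then pull the raw bytes with a second, blocking MPI_Recv.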
        Tensor val;
        if (mpi_response.singlesend()) {
          dst_device->MakeTensorFromProto(mpi_response.response().tensor(),
                                          recv_args.alloc_attrs, &val);
        } else {
          TensorResponse tr;
          tr.InitAlloc(dst_device, recv_args.alloc_attrs);
          tr.InitPartial(mpi_response.response(), AllocationAttributes());
          const size_t nBytes = tr.tensor().TotalBytes();
          void* data = const_cast<void*>(DMAHelper::base(&tr.tensor()));
          MPI_Status status;
          MPI_CHECK(MPI_Recv(data, static_cast<int>(nBytes), MPI_BYTE, dst,
                             TAG_SENDTENSOR2, MPI_COMM_WORLD, &status));
          val = std::move(tr.tensor());
        }

        done(s, Args(), recv_args, val, mpi_response.response().is_dead());
      };

  MPIRendezvousMgr* mgr =
      reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
  mgr->QueueRequest(string(parsed.FullKey()), step_id_, std::move(request_call),
                    rendezvous_call);
}
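
For context, here is a minimal standalone sketch of the wire pattern the code above relies on: a serialized request sent from MPI_Alloc_mem'd memory on one tag, followed by a raw tensor payload on a second tag. The tag values, ranks, buffer sizes, and the MPI_Probe-based sizing are illustrative assumptions, not the constants or mechanism used by TensorFlow (which sizes the receive from tr.tensor().TotalBytes() in the metadata response):

// two_phase_recv_sketch.cc -- illustrative only; not part of TensorFlow.
// Run with: mpirun -np 2 ./two_phase_recv_sketch
#include <mpi.h>

#include <cstdio>
#include <cstring>
#include <vector>

static const int TAG_REQTENSOR = 10;    // assumed values; the real constants
static const int TAG_SENDTENSOR2 = 12;  // live in the mpi_rendezvous headers

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  if (rank == 0) {
    // Requester: serialize a request into MPI-allocated memory and send it
    // without blocking, as request_call does above.
    const char kRequest[] = "serialized-RecvTensorRequest";
    void* req_buf = nullptr;
    MPI_Alloc_mem(sizeof(kRequest), MPI_INFO_NULL, &req_buf);
    std::memcpy(req_buf, kRequest, sizeof(kRequest));
    MPI_Request send_req;
    MPI_Isend(req_buf, static_cast<int>(sizeof(kRequest)), MPI_CHAR, 1,
              TAG_REQTENSOR, MPI_COMM_WORLD, &send_req);

    // The pending Isend must eventually be completed and its buffer freed;
    // in the rendezvous manager this falls to the MPI progress thread.
    MPI_Wait(&send_req, MPI_STATUS_IGNORE);
    MPI_Free_mem(req_buf);

    // Size the payload receive with a probe (illustrative; the code above
    // instead sizes it from the metadata: tr.tensor().TotalBytes()).
    MPI_Status st;
    MPI_Probe(1, TAG_SENDTENSOR2, MPI_COMM_WORLD, &st);
    int nbytes = 0;
    MPI_Get_count(&st, MPI_BYTE, &nbytes);
    std::vector<char> payload(nbytes);
    MPI_Recv(payload.data(), nbytes, MPI_BYTE, 1, TAG_SENDTENSOR2,
             MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    std::printf("received %d payload bytes\n", nbytes);
  } else if (rank == 1) {
    // Responder: take the request, then ship raw "tensor" bytes back on the
    // second tag, standing in for the remote side's two-phase send.
    char req[64] = {0};
    MPI_Recv(req, static_cast<int>(sizeof(req)), MPI_CHAR, 0, TAG_REQTENSOR,
             MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    std::vector<char> tensor_bytes(1 << 20, 0x2a);
    MPI_Send(tensor_bytes.data(), static_cast<int>(tensor_bytes.size()),
             MPI_BYTE, 0, TAG_SENDTENSOR2, MPI_COMM_WORLD);
  }

  MPI_Finalize();
  return 0;
}

Splitting metadata from payload lets the receiver allocate the tensor with the right allocator and attributes before the bulk bytes land, which is why the non-single-send branch above goes through TensorResponse::InitAlloc before its MPI_Recv.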