in tensorflow_networking/mpi/mpi_rendezvous_mgr.cc [61:140]
// Requests a tensor from a remote MPI rank.
//
// Builds an MPIRequestTensorCall describing the (key, step_id) pair, queues a
// request function that serializes and MPI_Isend's it to the source rank, and
// installs a completion callback (recv_call_) that materializes the tensor on
// the destination device once the remote side answers, finally invoking
// `done`. Ownership of the heap-allocated call object passes to the
// rendezvous manager via QueueRequest (it is released when the exchange
// completes), except on the early-error path below where it is freed here.
void MPIRemoteRendezvous::RecvFromRemoteAsync(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
    DoneCallback done) {
  MPIRequestTensorCall* rendezvous_call = new MPIRequestTensorCall();

  VLOG(2) << "MPI User requested " << parsed.FullKey()
          << " @ step: " << step_id_;

  std::string src_task = strings::StrCat(
      parsed.src.job, ":", parsed.src.replica, ":", parsed.src.task);
  const int dst = mpiutils_->GetSourceID(src_task);

  // Verify the destination device exists before queueing any MPI traffic.
  // On failure, report the error through `done` instead of crashing; the
  // call object has not been handed off yet, so it must be freed here.
  Device* dst_device;
  Status s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    delete rendezvous_call;
    done(s, Args(), recv_args, Tensor{}, false);
    return;
  }

  // Set properties of the request object and create the request function
  rendezvous_call->Init(parsed, step_id_);

  std::function<void()> request_call = [parsed, dst, rendezvous_call]() {
    // Use MPI_Alloc_mem here to force allocation inside MPI thread
    // this is not optimal, but prevents memory corruption and segmentation
    // faults during inter-server transfers...
    MPI_CHECK(MPI_Alloc_mem(rendezvous_call->request_buffer_size_,
                            MPI_INFO_NULL, &rendezvous_call->request_buffer_));
    rendezvous_call->req_.SerializeToArray(
        rendezvous_call->request_buffer_,
        rendezvous_call->request_buffer_size_);
    MPI_CHECK(MPI_Isend(rendezvous_call->request_buffer_,
                        rendezvous_call->request_buffer_size_, MPI_CHAR, dst,
                        TAG_REQTENSOR, MPI_COMM_WORLD,
                        &rendezvous_call->mpi_request_));
  };

  // Create the function which is called when the Tensor is sent by the
  // remote rank. Members cannot be captured directly into the lambda
  // (pre-C++14), so copy step_id_ into a local first.
  const int64 step_id = step_id_;
  rendezvous_call->recv_call_ =
      [this, parsed, recv_args, done, dst, step_id,
       rendezvous_call](MPIRecvTensorResponse mpi_response) {
        // Look the device up again: the callback may fire on a different
        // thread, long after RecvFromRemoteAsync returned.
        Device* dst_device;
        Status s =
            env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
        CHECK(s.ok()) << "Device lookup failed";

        VLOG(3) << "MPI Received tensor " << parsed.FullKey()
                << " @ step: " << step_id
                << " single-send: " << mpi_response.singlesend();

        Tensor val;
        if (mpi_response.singlesend()) {
          // The whole tensor travelled inside the response proto. Capture
          // the deserialization status so failures are reported to `done`
          // rather than silently dropped.
          s = dst_device->MakeTensorFromProto(mpi_response.response().tensor(),
                                              recv_args.alloc_attrs, &val);
        } else {
          // Large tensors are sent in a second raw MPI message: allocate the
          // tensor from the response metadata, then receive the payload
          // directly into its buffer.
          TensorResponse tr;
          tr.InitAlloc(dst_device, recv_args.alloc_attrs);
          tr.InitPartial(mpi_response.response(), AllocationAttributes());
          const size_t nBytes = tr.tensor().TotalBytes();
          void* data = const_cast<void*>(DMAHelper::base(&tr.tensor()));
          MPI_Status status;
          MPI_CHECK(MPI_Recv(data, static_cast<int>(nBytes), MPI_BYTE, dst,
                             TAG_SENDTENSOR2, MPI_COMM_WORLD, &status));
          val = std::move(tr.tensor());
        }

        done(s, Args(), recv_args, val, mpi_response.response().is_dead());
      };

  MPIRendezvousMgr* mgr =
      reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
  mgr->QueueRequest(string(parsed.FullKey()), step_id_, std::move(request_call),
                    rendezvous_call);
}