in tensorflow_networking/mpi/mpi_rendezvous_mgr.cc [150:269]
// Handles a RecvTensor request that arrived from MPI rank `mpi_dst`.
//
// Steps:
//   1. Registers the request id with recv_tensor_recent_request_ids_
//      (TF_CHECK_OK aborts if TrackUnique returns an error).
//   2. Parses the request's rendezvous key.
//   3. Schedules RecvLocalAsync on the worker compute pool. When the local
//      tensor becomes available, `done_cb` fills an MPISendTensorCall and
//      hands a bound `send_cb` closure to QueueSendRequest. That closure
//      performs the non-blocking MPI sends to `mpi_dst`.
//
// Args:
//   request: the incoming RecvTensorRequest (by value; only read here).
//   mpi_dst: destination rank in MPI_COMM_WORLD for the response.
void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
                                  const int mpi_dst) {
  TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique(
      request.request_id(), "RecvTensor (MPIRendezvousMgr)", request));
  const int64 step_id = request.step_id();
  const std::string& key = request.rendezvous_key();
  Rendezvous::ParsedKey parsed;
  TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));

  // Serializes mpi_send_call->mRes_ and issues the MPI_Isend(s) to mpi_dst.
  // Bound to its arguments in done_cb and run later by whoever drains the
  // send queue. Only mpi_dst is needed from the enclosing scope.
  // status / send_args / recv_args / is_dead are part of the callback type
  // and are not used here.
  MPIRecvTensorCallBack send_cb = [mpi_dst](
      const Status& status, const Rendezvous::Args& send_args,
      const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead,
      MPISendTensorCall* mpi_send_call) {
    // ByteSizeLong() returns size_t, so this bound check can actually fire.
    // ByteSize() returns an int that can never exceed INT_MAX. The size is
    // also computed once instead of once per use.
    // TODO(jbedorf) this should be a loop over max size
    const size_t msg_size = mpi_send_call->mRes_.ByteSizeLong();
    CHECK(msg_size < static_cast<size_t>(INT_MAX))
        << "Buffer too large for single transfer";
    const int msg_bytes = static_cast<int>(msg_size);
    MPI_CHECK(MPI_Alloc_mem(msg_bytes, MPI_INFO_NULL,
                            &mpi_send_call->send_buffer_));
    mpi_send_call->mRes_.SerializeToArray(mpi_send_call->send_buffer_,
                                          msg_bytes);
    MPI_CHECK(MPI_Isend(mpi_send_call->send_buffer_, msg_bytes, MPI_CHAR,
                        mpi_dst, TAG_SENDTENSOR, MPI_COMM_WORLD,
                        &(mpi_send_call->msg1_)));
    MPI_CHECK(MPI_Test(&mpi_send_call->msg1_, &mpi_send_call->done1_,
                       MPI_STATUS_IGNORE));
    if (!mpi_send_call->mRes_.singlesend()) {
      // The raw tensor bytes follow in a second message. MPI counts are int,
      // so refuse tensors that do not fit instead of silently truncating.
      // TODO(jbedorf) this should be a loop over max size
      CHECK(val.TotalBytes() < static_cast<size_t>(INT_MAX))
          << "Tensor too large for single transfer";
      const int tensor_size = static_cast<int>(val.TotalBytes());
      void* temp = const_cast<void*>(DMAHelper::base(&val));
      // If the MPI library is not GPU aware there should be a data transfer
      // here to get the data on the host.
      // if(src_dev->tensorflow_gpu_device_info()) //memcpy to send_buffer2_
      MPI_CHECK(MPI_Isend(temp, tensor_size, MPI_CHAR, mpi_dst, TAG_SENDTENSOR2,
                          MPI_COMM_WORLD, &mpi_send_call->msg2_));
      mpi_send_call->done2_ = 0;
    }
    return mpi_send_call;
  };

  // Wrapper around the read callback to place the callback on our queue
  Rendezvous::DoneCallback done_cb = [this, parsed, step_id, send_cb](
      const Status& status, const Rendezvous::Args& send_args,
      const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) {
    // A failed local receive is treated as fatal.
    CHECK(status.ok()) << "RecvLocalAsync was not ok, key: "
                       << parsed.FullKey() << " step: " << step_id
                       << " error message: " << status.error_message();
    // No std::endl: the logger terminates the line itself.
    VLOG(3) << "MPI Sending tensor " << parsed.FullKey()
            << " @ step: " << step_id;
    // NOTE(review): raw allocation; it is released somewhere outside this
    // function (not visible here) — confirm the send-queue consumer owns it.
    auto mpi_send_call = new MPISendTensorCall();
    mpi_send_call->Init(parsed, step_id, is_dead);
    Device* src_dev = nullptr;
    Status s = this->worker_env_2->device_mgr->LookupDevice(parsed.src_device,
                                                            &src_dev);
    CHECK(s.ok()) << "src device not found";
    // Control if shape and data should be send together or if we can
    // optimize it in two different transfers, thereby reducing memory
    // copies. Only memcpy-able dtypes of at least 1024 bytes qualify, and
    // only when use_optimal_transfer_ is enabled.
    const bool doOptimalTransfer = DataTypeCanUseMemcpy(val.dtype()) &&
                                   val.TotalBytes() >= 1024 &&
                                   use_optimal_transfer_;
    if (doOptimalTransfer) {
      // First send the Tensor description (dtype + shape) and in a follow up
      // transfer the data (see send_cb, TAG_SENDTENSOR2).
      mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype(
          val.dtype());
      val.shape().AsProto(mpi_send_call->mRes_.mutable_response()
                              ->mutable_tensor()
                              ->mutable_tensor_shape());
      mpi_send_call->mRes_.set_singlesend(false);
    } else {
      // Send the Tensor description and data in a single transfer
      if (src_dev->tensorflow_gpu_device_info() &&
          (!send_args.alloc_attrs.on_host())) {
        Notification n;
        GPUUtil::SetProtoFromGPU(
            val, src_dev, send_args.device_context,
            mpi_send_call->mRes_.mutable_response()->mutable_tensor(), is_dead,
            [&n, &s](const Status& s_) {
              s = s_;
              n.Notify();
            });
        n.WaitForNotification();
        // Do not ship a partially filled proto if the device-to-host copy
        // failed.
        CHECK(s.ok()) << "GPU to host copy failed, key: " << parsed.FullKey()
                      << " error message: " << s.error_message();
      } else {
        val.AsProtoTensorContent(
            mpi_send_call->mRes_.mutable_response()->mutable_tensor());
      }
    }
    std::function<MPISendTensorCall*()> res = std::bind(
        send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);
    SendQueueEntry req(string(parsed.FullKey()), std::move(res));
    this->QueueSendRequest(req);
    // Wait for the notification that indicates the tensor has been
    // successfully transmitted to the remote process. Only needed if we
    // have not parsed the tensor to proto. In that case the second MPI_Isend
    // reads directly from val's buffer.
    // NOTE(review): presumably n_ is notified only after msg2_ completes —
    // confirm.
    if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification();
  };  // done_cb

  worker_env_2->compute_pool->Schedule([this, step_id, parsed, done_cb]() {
    this->RecvLocalAsync(step_id, parsed, done_cb);
  });
}