in tensorflow_networking/gdr/gdr_worker.cc [45:146]
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
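  // Reject duplicate RecvTensor requests (e.g. retried RPCs) so that
  // each rendezvous receive is performed at most once.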
Status s = recv_tensor_recent_request_ids_.TrackUnique(
request->request_id(), "RecvTensor (GdrWorker)", *request);
if (!s.ok()) {
done(s);
return;
}
const int64 step_id = request->step_id();
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
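  // The rendezvous key encodes the source/destination devices and the
  // tensor name; parsing it lets us look up the local source device
  // that will produce the tensor.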
Rendezvous::ParsedKey parsed;
s = Rendezvous::ParseKey(key, &parsed);
Device* src_dev = nullptr;
if (s.ok()) {
s = PrepareRecvTensor(parsed, &src_dev);
}
if (!s.ok()) {
done(s);
return;
}
  // Request the tensor associated with the rendezvous key. At any
  // point while waiting for the tensor to be produced, up until the
  // callback lambda body below starts executing, an RPC cancellation
  // should abort the rendezvous.
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
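  // dma_ok is set by the requester when it can accept the tensor
  // contents out of band (via the GDR transport options) instead of
  // inline in the gRPC response.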
const bool dma_ok = request->dma_ok();
env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[this, opts, response, done, src_dev, request, dma_ok](
const Status& status, const Rendezvous::Args& send_args,
const Rendezvous::Args&, const Tensor& val, const bool is_dead) {
opts->ClearCancelCallback();
if (status.ok()) {
          // The DMA path is only taken for tensors that are worth the
          // transfer setup cost (more than 1024 bytes, which also
          // rules out zero-size buffers), are not dead (a dead tensor
          // has an uninitialized value), pass DMAHelper::CanUseDMA,
          // and were requested with dma_ok set. Note that the on_host
          // allocation attribute means the buffer is in CPU RAM
          // *independent of the tensor's assigned device type*; here
          // it is forwarded to the remote memory manager rather than
          // disabling DMA.
const bool on_host = send_args.alloc_attrs.on_host();
if (val.TotalBytes() > 1024 && (!is_dead) &&
DMAHelper::CanUseDMA(&val) && dma_ok) {
// DMA cases.
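            // Only metadata travels in the gRPC response: dtype,
            // shape, liveness, and transport_options describing the
            // remote memory region. The tensor bytes themselves are
            // exposed by remote_memory_manager_ for the peer to fetch
            // out of band (under GDR, typically an RDMA read).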
RecvTensorResponse* proto = new RecvTensorResponse;
proto->set_is_dead(is_dead);
proto->set_send_start_micros(Env::Default()->NowMicros());
TensorProto* tensor_proto = proto->mutable_tensor();
tensor_proto->set_dtype(val.dtype());
val.shape().AsProto(tensor_proto->mutable_tensor_shape());
auto transport_options = proto->mutable_transport_options();
remote_memory_manager_->TransportOptionsFromTensor(
transport_options, val, src_dev, send_args.device_context,
on_host, [proto, done, response](const Status& s) {
if (s.ok()) {
grpc::EncodeRecvTensorResponseToByteBuffer(*proto,
response);
done(Status::OK());
} else {
done(s);
}
delete proto;
});
} else {
// Non-DMA cases.
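            // The tensor must be serialized into the gRPC response,
            // which requires host memory. A GPU-resident tensor is
            // first staged into a gpu_compatible host allocation; a
            // host-resident one is encoded directly below.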
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
DeviceContext* send_dev_context = send_args.device_context;
AllocatorAttributes alloc_attrs;
alloc_attrs.set_gpu_compatible(true);
alloc_attrs.set_on_host(true);
Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
CHECK(send_dev_context)
<< "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on an accelerator device. Uses the device_context to
// fill the copy on host.
StatusCallback copy_ready = [response, done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
grpc::EncodeTensorToByteBuffer(is_dead, *copy, false, response);
done(s);
delete copy;
};
send_dev_context->CopyDeviceTensorToCPU(
&val, request->rendezvous_key(), src_dev, copy, copy_ready);
} else {
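              // The tensor is already in host memory; encode it
              // straight into the response buffer.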
grpc::EncodeTensorToByteBuffer(is_dead, val, false, response);
done(Status::OK());
}
}
} else {
          // !status.ok()
done(status);
}
});
}