in tensorflow_networking/gdr/gdr_memory_manager.cc [419:533]
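// Populates `tensor` with the contents of a remote memory region by posting a
// one-sided RDMA read. The remote region (host, port, address, rkey, and
// tensor key) is described by the RemoteMemoryRegion packed into
// `transport_options`; `done` is invoked once the read completes or fails.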
void GdrMemoryManager::TensorFromTransportOptions(
    Tensor* tensor, const ::google::protobuf::Any& transport_options,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  RemoteMemoryRegion remote_mr;
  if (!transport_options.UnpackTo(&remote_mr)) {
    done(errors::NotFound("No RDMA transport options found"));
    return;
  }
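  // Look up, or lazily create, the RDMA endpoint for the remote host. The
  // endpoint cache `clients_` is keyed by (host, port) and guarded by
  // `client_mu_`.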
  rdma_cm_id* id = nullptr;
  {
    decltype(clients_)::iterator iter;
    bool success;
    mutex_lock l(client_mu_);
    std::tie(iter, success) = clients_.insert(
        std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
                       RdmaEndpointPtr(nullptr, EndpointDeleter)));
    if (success || iter->second.get() == nullptr) {
      Status s =
          CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second);
      if (!s.ok()) {
        done(s);
        return;
      }
    }
    id = iter->second.get();
  }
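  // Read directly into the tensor's own buffer when it is backed by a
  // registered memory region; otherwise stage the read into a pinned,
  // NIC-compatible host copy.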
  ibv_mr* mr = FindMemoryRegion(tensor);
  const TensorBuffer* buffer = DMAHelper::buffer(tensor);
  const Tensor* copy = nullptr;
  if (mr == nullptr) {
    AllocatorAttributes alloc_attrs;
    alloc_attrs.set_gpu_compatible(true);
    alloc_attrs.set_nic_compatible(true);
    alloc_attrs.set_on_host(true);
    Allocator* alloc = device->GetAllocator(alloc_attrs);
    copy = new Tensor(alloc, tensor->dtype(), tensor->shape());
    mr = FindMemoryRegion(copy);
    buffer = DMAHelper::buffer(copy);
    if (mr == nullptr) {
      done(errors::Unavailable("Cannot find pinned memory region"));
      delete copy;
      return;
    }
  }
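  // Completion callback, registered below under `tensor_key` and invoked once
  // the posted read completes (or fails). On success it moves the data out of
  // the staging copy, if any, into the destination tensor before calling
  // `done`.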
  uint64_t start = Env::Default()->NowMicros();
  TensorKey tensor_key = remote_mr.tensor_key();
  StatusCallback callback = [done, copy, device, device_context, on_host,
                             tensor, start, tensor_key](const Status& s) {
    if (!s.ok()) {
      done(s);
      if (copy) {
        delete copy;
      }
      return;
    }
    VLOG(2) << "RDMA of tensor " << tensor_key << " of size "
            << DMAHelper::buffer(tensor)->size() << " took "
            << (Env::Default()->NowMicros() - start) << " micros";
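    // If the read was staged into a pinned host copy and the destination is a
    // GPU tensor, copy the result onto the device asynchronously; for a host
    // destination a plain memcpy suffices.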
    if (copy && device->tensorflow_gpu_device_info() && !on_host) {
      device_context->CopyCPUTensorToDevice(copy, device, tensor,
                                            [done, copy](const Status& s) {
                                              done(s);
                                              delete copy;
                                            });
    } else if (copy) {
      std::memcpy(DMAHelper::buffer(tensor)->data(),
                  DMAHelper::buffer(copy)->data(),
                  DMAHelper::buffer(copy)->size());
      done(s);
      delete copy;
    } else {
      done(s);
    }
  };
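  // Register the completion callback under `tensor_key`. An already-present
  // key indicates a duplicate request and is reported as an error.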
  {
    mutex_lock l(callback_mu_);
    if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) {
      tensor_callbacks_.insert(
          std::make_pair(tensor_key, std::move(callback)));
    } else {
      done(errors::Unavailable("Received duplicated tensor key"));
      if (copy) {
        delete copy;
      }
      return;
    }
  }
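  // Post the one-sided RDMA read from the remote region into the local
  // buffer. If posting fails, report the error, deregister the callback, and
  // free the staging copy, since no completion will be delivered.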
  if (rdma_post_read(id, reinterpret_cast<void*>(tensor_key), buffer->data(),
                     buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(),
                     remote_mr.rkey())) {
    done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed"));
    {
      mutex_lock l(callback_mu_);
      auto iter = tensor_callbacks_.find(tensor_key);
      if (iter != std::end(tensor_callbacks_)) {
        tensor_callbacks_.erase(iter);
      }
    }
    if (copy) {
      delete copy;
    }
  }
}