// GdrMemoryManager::TensorFromTransportOptions()
// Extracted from tensorflow_networking/gdr/gdr_memory_manager.cc, lines 419-533.

// Fills `tensor` with bytes pulled from a remote host via a one-sided RDMA
// read. The remote side's address, rkey, host/port and tensor key are carried
// in `transport_options` as a packed RemoteMemoryRegion proto.
//
// Flow:
//   1. Unpack the RemoteMemoryRegion; fail with NotFound if absent.
//   2. Look up (or lazily create) a cached RDMA endpoint to host:port.
//   3. Locate a registered memory region (MR) covering the tensor's buffer.
//      If the buffer is not pinned (e.g. it lives on a GPU), allocate a
//      pinned host-side staging Tensor (`copy`) and RDMA into that instead.
//   4. Register a completion callback keyed by the remote tensor key, then
//      post the RDMA read. The callback is invoked elsewhere (presumably by
//      the completion-queue polling logic — not visible in this chunk) once
//      the read finishes, and copies/moves the staged bytes to their final
//      destination before calling `done`.
//
// `done` is invoked exactly once on every path with the final Status.
// NOTE(review): `copy` is a raw owning pointer deleted on four separate
// paths; any new early return must remember to free it.
void GdrMemoryManager::TensorFromTransportOptions(
    Tensor* tensor, const ::google::protobuf::Any& transport_options,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  RemoteMemoryRegion remote_mr;
  if (!transport_options.UnpackTo(&remote_mr)) {
    done(errors::NotFound("No RDMA transport options found"));
    return;
  }

  // Get or create a client endpoint for the remote host:port. The cache
  // (`clients_`) is guarded by `client_mu_`; insert() either creates a null
  // placeholder (success == true) or finds an existing entry. A null entry
  // (placeholder from a previous failed attempt) is (re)connected here.
  rdma_cm_id* id = nullptr;
  {
    decltype(clients_)::iterator iter;
    bool success;
    mutex_lock l(client_mu_);
    std::tie(iter, success) = clients_.insert(
        std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
                       RdmaEndpointPtr(nullptr, EndpointDeleter)));
    if (success || iter->second.get() == nullptr) {
      Status s =
          CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second);
      if (!s.ok()) {
        done(s);
        return;
      }
    }
    // NOTE(review): the raw id is used after `client_mu_` is released;
    // assumes endpoints are never evicted from `clients_` concurrently —
    // verify against the rest of this class.
    id = iter->second.get();
  }

  ibv_mr* mr = FindMemoryRegion(tensor);
  const TensorBuffer* buffer = DMAHelper::buffer(tensor);

  // Staging tensor, only allocated when the destination buffer itself is not
  // backed by a registered (pinned) memory region. Owned by this function
  // and the callbacks below; freed with `delete` on every terminal path.
  const Tensor* copy = nullptr;

  if (mr == nullptr) {
    // Destination is not RDMA-addressable: allocate a pinned, NIC- and
    // GPU-compatible host buffer to receive the read, then copy onward.
    AllocatorAttributes alloc_attrs;
    alloc_attrs.set_gpu_compatible(true);
    alloc_attrs.set_nic_compatible(true);
    alloc_attrs.set_on_host(true);
    Allocator* alloc = device->GetAllocator(alloc_attrs);
    copy = new Tensor(alloc, tensor->dtype(), tensor->shape());

    mr = FindMemoryRegion(copy);
    buffer = DMAHelper::buffer(copy);
    if (mr == nullptr) {
      // Even the freshly allocated buffer is not covered by a registered MR;
      // nothing more we can do.
      done(errors::Unavailable("Cannot find pinned memory region"));
      delete copy;
      return;
    }
  }

  // Start time for the VLOG latency measurement below.
  uint64_t start = Env::Default()->NowMicros();

  TensorKey tensor_key = remote_mr.tensor_key();

  // Completion callback: runs once the RDMA read has finished (or failed).
  // On success it moves the staged bytes to the final destination:
  //   - staging + GPU destination: async device copy, then done();
  //   - staging + host destination: memcpy, then done();
  //   - no staging (read landed in place): just done().
  // Captures `copy`, `tensor`, `device`, `device_context` by value/pointer;
  // assumes all of them outlive the RDMA operation — presumably guaranteed
  // by the caller's contract, but not visible here.
  StatusCallback callback = [done, copy, device, device_context, on_host,
                             tensor, start, tensor_key](const Status& s) {
    if (!s.ok()) {
      done(s);
      if (copy) {
        delete copy;
      }
      return;
    }

    VLOG(2) << "RDMA of tensor " << tensor_key << " of size "
            << DMAHelper::buffer(tensor)->size() << " took "
            << (Env::Default()->NowMicros() - start) << " micros";

    if (copy && device->tensorflow_gpu_device_info() && !on_host) {
      device_context->CopyCPUTensorToDevice(copy, device, tensor,
                                            [done, copy](const Status& s) {
                                              done(s);
                                              delete copy;
                                            });
    } else if (copy) {
      std::memcpy(DMAHelper::buffer(tensor)->data(),
                  DMAHelper::buffer(copy)->data(),
                  DMAHelper::buffer(copy)->size());
      done(s);
      delete copy;
    } else {
      done(s);
    }
  };

  // Register the callback under the remote tensor key so the completion
  // handler can find it. A key collision means a previous transfer with the
  // same key is still in flight — treat as an error rather than overwrite.
  {
    mutex_lock l(callback_mu_);
    if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) {
      tensor_callbacks_.insert(std::make_pair(tensor_key, std::move(callback)));
    } else {
      done(errors::Unavailable("Received duplicated tensor key"));
      if (copy) {
        delete copy;
      }
      return;
    }
  }

  // Post the one-sided read; the tensor key doubles as the work-request id
  // (wr_id) so the completion can be matched back to the callback above.
  // IBV_SEND_SIGNALED requests a completion-queue entry for this read.
  if (rdma_post_read(id, reinterpret_cast<void*>(tensor_key), buffer->data(),
                     buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(),
                     remote_mr.rkey())) {
    // Posting failed: report, unregister the callback we just installed,
    // and free the staging buffer.
    // NOTE(review): `done` fires before the callback is erased; if a failed
    // rdma_post_read could still generate a completion, the stored callback
    // might race with this cleanup and double-invoke `done` / double-delete
    // `copy` — confirm against the CQ-polling code.
    done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed"));
    {
      mutex_lock l(callback_mu_);
      auto iter = tensor_callbacks_.find(tensor_key);
      if (iter != std::end(tensor_callbacks_)) {
        tensor_callbacks_.erase(iter);
      }
    }
    if (copy) {
      delete copy;
    }
  }
}