void RdmaMgr::SetupChannels()

in tensorflow_networking/verbs/rdma_mgr.cc [58:131]


void RdmaMgr::SetupChannels() {
  for (const auto& p : channel_table_) {
    string worker_name = p.first;
    RDMA_LOG(2) << "Connecting to remote node " << worker_name;
    RdmaChannel* rc = p.second;
    GetRemoteAddressRequest req;
    GetRemoteAddressResponse resp;
    // get the channel cache
    SharedGrpcChannelPtr client_channel =
        channel_cache_->FindWorkerChannel(worker_name);
    GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
    CHECK(client != nullptr) << "No worker known as " << worker_name;

    // setting up request
    req.set_host_name(local_worker_);
    Channel* channel_info = req.mutable_channel();
    channel_info->set_lid(rc->self_.lid);
    channel_info->set_qpn(rc->self_.qpn);
    channel_info->set_psn(rc->self_.psn);
    channel_info->set_snp(rc->self_.snp);
    channel_info->set_iid(rc->self_.iid);
    for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
      MemoryRegion* mr = req.add_mr();
      mr->set_remote_addr(
          reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
      mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
    }
    // synchronous call
    Status s;
    int attempts = 0;
    static const int max_num_attempts = 5;
    do {
      s = client->GetRemoteAddress(&req, &resp);
      // save obtained remote addresses
      // connect to the remote channel
      if (s.ok()) {
        CHECK(worker_name.compare(resp.host_name()) == 0);
        RdmaAddress ra;
        ra.lid = resp.channel().lid();
        ra.qpn = resp.channel().qpn();
        ra.psn = resp.channel().psn();
        ra.snp = resp.channel().snp();
        ra.iid = resp.channel().iid();
        rc->SetRemoteAddress(ra, false);
        rc->Connect();
        int i = 0;
        int idx[] = {1, 0};
        for (const auto& mr : resp.mr()) {
          // the connections are crossed, i.e.
          // local tx_message_buffer <---> remote rx_message_buffer_
          // local rx_message_buffer <---> remote tx_message_buffer_
          // hence idx[] = {1, 0}.
          RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
          RemoteMR rmr;
          rmr.remote_addr = mr.remote_addr();
          rmr.rkey = mr.rkey();
          rb->SetRemoteMR(rmr, false);
          i++;
        }
        CHECK(i == RdmaChannel::kNumMessageBuffers);
      } else {
        LOG(ERROR) << "Connecting to " << worker_name << ": Got "
                   << s.error_message() << ". Retrying (" << (attempts + 1)
                   << "/" << max_num_attempts << ")...";
        if (++attempts == max_num_attempts) {
          break;
        }
        worker_env_->env->SleepForMicroseconds(2000000);
      }
    } while (!s.ok());
    RDMA_LOG(0) << "Connected to remote node " << worker_name;
    delete client;
  }
}