in tensorflow_networking/verbs/rdma_mgr.cc [58:131]
void RdmaMgr::SetupChannels() {
for (const auto& p : channel_table_) {
string worker_name = p.first;
RDMA_LOG(2) << "Connecting to remote node " << worker_name;
RdmaChannel* rc = p.second;
GetRemoteAddressRequest req;
GetRemoteAddressResponse resp;
// get the channel cache
SharedGrpcChannelPtr client_channel =
channel_cache_->FindWorkerChannel(worker_name);
GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
CHECK(client != nullptr) << "No worker known as " << worker_name;
// setting up request
req.set_host_name(local_worker_);
Channel* channel_info = req.mutable_channel();
channel_info->set_lid(rc->self_.lid);
channel_info->set_qpn(rc->self_.qpn);
channel_info->set_psn(rc->self_.psn);
channel_info->set_snp(rc->self_.snp);
channel_info->set_iid(rc->self_.iid);
for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
MemoryRegion* mr = req.add_mr();
mr->set_remote_addr(
reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
}
// synchronous call
Status s;
int attempts = 0;
static const int max_num_attempts = 5;
do {
s = client->GetRemoteAddress(&req, &resp);
// save obtained remote addresses
// connect to the remote channel
if (s.ok()) {
CHECK(worker_name.compare(resp.host_name()) == 0);
RdmaAddress ra;
ra.lid = resp.channel().lid();
ra.qpn = resp.channel().qpn();
ra.psn = resp.channel().psn();
ra.snp = resp.channel().snp();
ra.iid = resp.channel().iid();
rc->SetRemoteAddress(ra, false);
rc->Connect();
int i = 0;
int idx[] = {1, 0};
for (const auto& mr : resp.mr()) {
// the connections are crossed, i.e.
// local tx_message_buffer <---> remote rx_message_buffer_
// local rx_message_buffer <---> remote tx_message_buffer_
// hence idx[] = {1, 0}.
RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
RemoteMR rmr;
rmr.remote_addr = mr.remote_addr();
rmr.rkey = mr.rkey();
rb->SetRemoteMR(rmr, false);
i++;
}
CHECK(i == RdmaChannel::kNumMessageBuffers);
} else {
LOG(ERROR) << "Connecting to " << worker_name << ": Got "
<< s.error_message() << ". Retrying (" << (attempts + 1)
<< "/" << max_num_attempts << ")...";
if (++attempts == max_num_attempts) {
break;
}
worker_env_->env->SleepForMicroseconds(2000000);
}
} while (!s.ok());
RDMA_LOG(0) << "Connected to remote node " << worker_name;
delete client;
}
}