in tensorflow_networking/gdr/gdr_memory_manager.cc [268:357]
void GdrMemoryManager::Run() {
stopped_ = false;
while (!stopped_) {
rdma_cm_id* id = nullptr;
// Accept incoming connections
if (!rdma_get_request(listening_.get(), &id)) {
if (!rdma_accept(id, nullptr)) {
LOG(INFO) << "Accepted new RDMA connection";
for (int i = 0; i < 1024; i++) {
if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
EndpointDeleter(id);
continue;
}
}
server_clients_.push_back({id, EndpointDeleter});
}
}
// Polling server side work completions
for (const auto& client : server_clients_) {
ibv_wc wc[32];
int ret = ibv_poll_cq(client->recv_cq, 32, wc);
if (ret < 0) {
LOG(ERROR) << "ibv_poll_cq failed";
continue;
}
for (int i = 0; i < ret; i++) {
if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
LOG(ERROR) << "Received unknown operation " << wc[i].opcode;
}
if (wc[i].status != 0) {
LOG(ERROR) << ibv_wc_status_str(wc[i].status);
}
TensorKey tensor_key = ntohl(wc[i].imm_data);
if (rdma_post_recvv(client.get(), nullptr, nullptr, 0)) {
perror("rdma_post_recvv");
LOG(ERROR) << "rdma_post_recvv failed";
}
mutex_lock l(buf_mu_);
auto iter = tensor_buffers_.find(tensor_key);
if (iter == std::end(tensor_buffers_)) {
LOG(ERROR) << "Cannot find tensor buffer for tensor key "
<< tensor_key;
} else {
const TensorBuffer* buffer = iter->second;
buffer->Unref();
tensor_buffers_.erase(iter);
}
}
}
// Polling client side work completions
if (client_mu_.try_lock()) {
for (const auto& client : clients_) {
ibv_wc wc[32];
int ret = ibv_poll_cq(client.second->send_cq, 32, wc);
for (int i = 0; i < ret; i++) {
Status s;
if (wc[i].status) {
s = errors::Unavailable(ibv_wc_status_str(wc[i].status));
} else {
s = Status::OK();
}
TensorKey key = wc[i].wr_id;
ibv_send_wr wr = {};
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr.imm_data = htonl(key);
ibv_send_wr* bad_wr;
if (ibv_post_send(client.second->qp, &wr, &bad_wr)) {
LOG(ERROR) << strerror(errno)
<< ": ibv_post_send failed for tensor_key " << key;
}
mutex_lock l(callback_mu_);
auto iter = tensor_callbacks_.find(key);
if (iter != std::end(tensor_callbacks_)) {
iter->second(s);
tensor_callbacks_.erase(iter);
} else {
LOG(WARNING) << "Cannot find client callback with tensor key "
<< key;
}
}
}
client_mu_.unlock();
}
}
}