void RdmaAdapter::Process_CQ()

in tensorflow_networking/verbs/rdma.cc [435:525]
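
Completion-queue processing loop for the verbs transport: it blocks on the
completion event channel, drains work completions, and dispatches each one.
Incoming RDMA writes with immediate data carry either an ACK, tensor content
for a pending request, or a control message; completions of locally issued
writes release the corresponding buffers and response objects.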


void RdmaAdapter::Process_CQ() {
  while (true) {
    ibv_cq* cq;
    void* cq_context;
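    // Block until the completion channel reports an event on our CQ, then
    // acknowledge it and re-arm notifications before draining completions.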
    CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
    CHECK(cq == cq_);
    ibv_ack_cq_events(cq, 1);
    CHECK(!ibv_req_notify_cq(cq_, 0));

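    // Drain up to MAX_CONCURRENT_WRITES * 2 completions in a single poll.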
    int ne =
        ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
    CHECK_GE(ne, 0);
    for (int i = 0; i < ne; ++i) {
      CHECK(wc_[i].status == IBV_WC_SUCCESS)
          << "Failed status \n"
          << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
          << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
      if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
        RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
        // Re-post a receive work request so the receive queue stays full.
        rc->Recv();
        // imm_data identifies what the remote side wrote: an ACK, a tensor
        // request index, or a control message.
        uint32_t imm_data = wc_[i].imm_data;
        RdmaMessageBuffer* rb;
        RdmaMessage rm;

        if (imm_data == RDMA_IMM_DATA_ACK) {
          // The peer acknowledged our last control message: its receive
          // buffer is free again, so the next queued message can go out.
          rb = rc->tx_message_buffer_;
          rb->SetBufferStatus(remote, idle);
          rb->SendNextItem();
          continue;
        }

        if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) {
          // receive a tensor RDMA write
          uint32_t request_index = imm_data;
          RdmaTensorRequest* request = rc->GetTensorRequest(request_index);
          request->RecvTensorContent();
          continue;
        }

        // receive a control message
        rb = rc->rx_message_buffer_;
        RdmaMessage::ParseMessage(rm, rb->buffer_);
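        // Ack right away so the sender can reuse its message buffer.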
        RdmaMessageBuffer::SendAck(rc);
        RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
                    << ": Received " << MessageTypeToString(rm.type_) << " "
                    << "#" << rm.request_index_ << ": " << rm.name_;

        if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
          RdmaTensorResponse* response = rc->AddTensorResponse(rm);
          response->Start();
        } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) {
          RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
          request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_,
                                      rm.is_dead_, rm.tensor_bytes_);
#ifdef RDMA_DATA_VALIDATION
          request->RecvTensorChecksum(rm.checksum_);
#endif
        } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) {
          RdmaTensorResponse* response = rc->UpdateTensorResponse(rm);
          response->Resume();
        } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
          RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
          request->RecvErrorStatus(rm.status_);
        }
      } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
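        // A locally issued RDMA write finished; wr_id is the heap-allocated
        // RdmaWriteID describing it (deleted once handled below).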
        RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id);
        RDMA_LOG(2) << "Write complete of type " << wr_id->write_type;
        switch (wr_id->write_type) {
          case RDMA_WRITE_ID_ACK:
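            // Ack writes need no further handling.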
            break;
          case RDMA_WRITE_ID_MESSAGE: {
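            // A control message went out; mark the local side of the buffer
            // idle and push the next queued item, if any.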
            RdmaMessageBuffer* rb =
                reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context);
            rb->SetBufferStatus(local, idle);
            rb->SendNextItem();
            break;
          }
          case RDMA_WRITE_ID_TENSOR_WRITE: {
            // The tensor content has reached the remote side; the response
            // object is no longer needed.
            RdmaTensorResponse* response =
                reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context);
            response->Destroy();
            break;
          }
        }
        delete wr_id;
      }
    }
  }
}
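
For context, the notify/wait/ack/re-arm/poll handshake at the top of the loop
is the standard libibverbs completion-channel pattern. Below is a minimal
sketch of that pattern, kept separate from rdma.cc: the function name
ProcessCompletions and the batch size kMaxEntries are illustrative, and unlike
Process_CQ above it keeps polling until the CQ is empty.

#include <infiniband/verbs.h>

#include <cstdio>

void ProcessCompletions(ibv_comp_channel* channel, ibv_cq* cq) {
  constexpr int kMaxEntries = 32;  // illustrative poll batch size
  ibv_wc wc[kMaxEntries];

  // Arm notifications once before waiting for the first event.
  if (ibv_req_notify_cq(cq, 0)) return;

  while (true) {
    // Sleep until the completion channel reports an event on the CQ.
    ibv_cq* ev_cq;
    void* ev_ctx;
    if (ibv_get_cq_event(channel, &ev_cq, &ev_ctx)) return;

    // Every event returned by ibv_get_cq_event must be acknowledged.
    ibv_ack_cq_events(ev_cq, 1);

    // Re-arm before polling so a completion arriving in between still
    // triggers the next event.
    if (ibv_req_notify_cq(ev_cq, 0)) return;

    // Drain everything currently queued on the CQ.
    int ne;
    while ((ne = ibv_poll_cq(ev_cq, kMaxEntries, wc)) > 0) {
      for (int i = 0; i < ne; ++i) {
        if (wc[i].status != IBV_WC_SUCCESS) {
          std::fprintf(stderr, "completion error: %s\n",
                       ibv_wc_status_str(wc[i].status));
          continue;
        }
        // ... dispatch on wc[i].opcode, as Process_CQ does above ...
      }
    }
  }
}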