in src/nccl_ofi_rdma.cpp [2720:2800]
/*
 * Poll an RDMA request for completion.
 *
 * If the request is neither COMPLETED nor ERROR, the endpoint's
 * completion queue is progressed once via ofi_process_cq().
 *
 * If the request is then COMPLETED, the following happens in order:
 *   1. Send/recv requests have their sequence number completed in the
 *      owning comm's message buffer.
 *   2. The matching end-of-operation trace is emitted.
 *   3. Completion is reported through *done / *size.
 *   4. The request is released through req->free().
 *
 * @param base_req  Request to test. Its type must be WRITE, READ, SEND,
 *                  RECV or FLUSH (asserted).
 * @param done      Out. Set to 0 on entry. Set to 1 only when the request
 *                  completed successfully and was handed to req->free().
 * @param size      Optional out (may be NULL). The request's size, read
 *                  under req->req_lock. Written only when *done becomes 1.
 * @return 0 on success, whether or not the request is done yet.
 *         The non-zero value returned by ofi_process_cq() if CQ
 *         processing failed.
 *         -EINVAL if the request is in the ERROR state, or if its
 *         message-buffer entry could not be completed.
 */
static int test(nccl_net_ofi_req_t *base_req, int *done, int *size)
{
	int ret = 0;
	nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)base_req;

	*done = 0;

	assert(req->type == NCCL_OFI_RDMA_WRITE ||
	       req->type == NCCL_OFI_RDMA_READ ||
	       req->type == NCCL_OFI_RDMA_SEND ||
	       req->type == NCCL_OFI_RDMA_RECV ||
	       req->type == NCCL_OFI_RDMA_FLUSH);

	/* Retrieve and validate comm */
	nccl_net_ofi_comm_t *base_comm = req->comm;
	assert(base_comm != NULL);

	/* Retrieve and validate endpoint */
	nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_comm->ep;
	assert(ep != NULL);

	/* Process more completions unless the current request has already
	 * reached COMPLETED or ERROR */
	if (req->state != NCCL_OFI_RDMA_REQ_COMPLETED
	    && OFI_LIKELY(req->state != NCCL_OFI_RDMA_REQ_ERROR)) {
		ret = ofi_process_cq(ep);
		if (OFI_UNLIKELY(ret != 0))
			goto exit;
	}

	/* Determine whether the request has finished without error and free if done */
	if (OFI_LIKELY(req->state == NCCL_OFI_RDMA_REQ_COMPLETED)) {
		size_t req_size;

		/* req->size is read under req_lock */
		nccl_net_ofi_mutex_lock(&req->req_lock);
		req_size = req->size;
		nccl_net_ofi_mutex_unlock(&req->req_lock);

		if (req->type == NCCL_OFI_RDMA_SEND || req->type == NCCL_OFI_RDMA_RECV) {
			/* Mark as complete in the comm's message buffer. The
			 * enclosing condition restricts the type to SEND or
			 * RECV, so those are the only two cases to handle. */
			nccl_ofi_msgbuff_t *msgbuff = (req->type == NCCL_OFI_RDMA_SEND)
				? ((nccl_net_ofi_rdma_send_comm_t *)base_comm)->msgbuff
				: ((nccl_net_ofi_rdma_recv_comm_t *)base_comm)->msgbuff;

			nccl_ofi_msgbuff_status_t stat;
			nccl_ofi_msgbuff_result_t mb_res =
				nccl_ofi_msgbuff_complete(msgbuff, req->msg_seq_num, &stat);
			if (OFI_UNLIKELY(mb_res != NCCL_OFI_MSGBUFF_SUCCESS)) {
				NCCL_OFI_WARN("Invalid result of msgbuff_complete for msg %hu", req->msg_seq_num);
				ret = -EINVAL;
				goto exit;
			}
		}

		if (req->type == NCCL_OFI_RDMA_SEND) {
			NCCL_OFI_TRACE_SEND_END(req->dev_id, base_comm, req);
		} else if (req->type == NCCL_OFI_RDMA_RECV) {
			NCCL_OFI_TRACE_RECV_END(req->dev_id, base_comm, req);
		}

		/* Publish completion only after every fallible step above has
		 * succeeded, so an error return never comes with *done == 1.
		 * The int-typed out parameter forces a narrowing here; this
		 * presumably relies on request sizes fitting in int — confirm
		 * against the posted buffer sizes. */
		if (size)
			*size = (int)req_size;
		*done = 1;

		assert(req->free);
		req->free(req, true);
	} else if (OFI_UNLIKELY(req->state == NCCL_OFI_RDMA_REQ_ERROR)) {
		NCCL_OFI_WARN("Request %p (type %d) completed with error",
			      (void *)req, (int)req->type);
		ret = -EINVAL;
		goto exit;
	}

exit:
	return ret;
}