in src/nccl_ofi_sendrecv.cpp [1216:1363]
static int sendrecv_recv_comm_flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
				    int *sizes, nccl_net_ofi_mr_handle_t **mhandles,
				    nccl_net_ofi_req_t **base_req)
{
	int ret = 0;
	nccl_net_ofi_sendrecv_recv_comm_t *r_comm =
		(nccl_net_ofi_sendrecv_recv_comm_t *)recv_comm;
	nccl_net_ofi_sendrecv_req_t *req = NULL;
	ssize_t rc = 0;
	uint64_t cuda_key = 0ULL;
	nccl_net_ofi_sendrecv_mr_handle_t *mr_handle = NULL;
	void *data = NULL;
	void *flush_mr_desc = NULL;
	int dev_id = recv_comm->base.dev_id;
	int flush_n = -1;

	auto **mr_handles = reinterpret_cast<nccl_net_ofi_sendrecv_mr_handle_t **>(mhandles);
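
	/* Flushing is only meaningful when GPUDirect RDMA is in use; bail out
	 * immediately if the user disabled GDR flush or the platform does not
	 * support GDR. */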
	if (ofi_nccl_gdr_flush_disable() || support_gdr == GDR_UNSUPPORTED)
		goto exit;
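
	/* When CUDA-side flushing is enabled, flush pending GPUDirect RDMA
	 * writes through the CUDA API instead of issuing the RDMA read below. */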
#if HAVE_CUDA
	if (cuda_flush) {
		ret = nccl_net_ofi_cuda_flush_gpudirect_rdma_writes();
		if (ret != 0) {
			NCCL_OFI_WARN("Error performing CUDA GDR flush");
		}
		goto exit;
	}
#endif

	/* Plugin only supports one receive per request */
	assert(n <= NCCL_OFI_MAX_RECVS);

	/*
	 * Find the first non-zero-sized receive for which we will issue the
	 * flush. A single read operation can flush all requests at once.
	 */
	for (int recv_n = 0; recv_n < n; recv_n++) {
		if (sizes[recv_n] != 0) {
			flush_n = recv_n;
			break;
		}
	}

	if (flush_n == -1) {
		/*
		 * Flush is an expensive operation, so don't issue fi_read for
		 * 0-sized messages. Since NCCL issues a flush for every irecv(),
		 * data is still guaranteed to be synced to the GPU even without it.
		 */
		goto exit;
	}

	if (mr_handles && mr_handles[flush_n]) {
		mr_handle = mr_handles[flush_n];
	}

	data = buffers[flush_n];

	/* Support at most NCCL_OFI_MAX_REQUESTS inflight requests. */
	if (OFI_UNLIKELY(r_comm->num_inflight_reqs == NCCL_OFI_MAX_REQUESTS)) {
		ret = -ENOSPC;
		NCCL_OFI_WARN("Can not support more than %d inflight requests",
			      NCCL_OFI_MAX_REQUESTS);
		goto exit;
	}

	/* Allocate NCCL OFI request */
	req = sendrecv_allocate_req(r_comm->nccl_ofi_reqs_fl);
	if (OFI_UNLIKELY(req == NULL)) {
		ret = -ENOTSUP;
		NCCL_OFI_WARN("Unable to get NCCL OFI request for device %d",
			      dev_id);
		goto exit;
	}

	req->comm = &r_comm->base.base;
	req->dev_id = dev_id;
	req->direction = NCCL_OFI_SENDRECV_RECV;

	if (r_comm->flush_buff.mr_handle != NULL) {
		/* Not checking flush_mr_desc for NULL, as fi_mr_desc()
		 * returns a valid descriptor for a valid handle */
		flush_mr_desc = fi_mr_desc(r_comm->flush_buff.mr_handle->mr);
	}

	if (mr_handle != NULL && mr_handle->mr != nullptr) {
		/* Extract remote key */
		cuda_key = fi_mr_key(mr_handle->mr);
		if (OFI_UNLIKELY(cuda_key == FI_KEY_NOTAVAIL)) {
			ret = -ENOTSUP;
			NCCL_OFI_WARN("Memory registration may not have completed.");
			goto error;
		}
	}

	NCCL_OFI_TRACE_FLUSH_SENDRECV(req, base_req);

	/* Issue RDMA read */
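	/* Reading the receive buffer back into the host-side flush buffer acts
	 * as a dummy read-back: its completion is what guarantees that
	 * previously delivered GPUDirect RDMA writes are visible in GPU memory. */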
	do {
		rc = fi_read(r_comm->local_ep, r_comm->flush_buff.host_buffer,
			     r_comm->flush_buff.size,
			     flush_mr_desc,
			     r_comm->local_ep_addr,
			     (uint64_t)(virt_addr_mr ? data : 0),
			     cuda_key, sendrecv_req_get_ofi_context(req));
		if (rc == 0) {
			break;
		} else if (rc == -FI_EAGAIN) {
			/* Retrieve and validate endpoint */
			nccl_net_ofi_sendrecv_ep_t *ep =
				(nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep;
			if (OFI_UNLIKELY(ep == NULL)) {
				ret = -EINVAL;
				NCCL_OFI_WARN("Invalid endpoint provided");
				goto error;
			}

			/*
			 * Process completions so that enough resources are
			 * available for issuing the fi_read again.
			 */
			ret = sendrecv_cq_process(ep->cq);
			if (OFI_UNLIKELY(ret != 0))
				goto error;
		} else {
			NCCL_OFI_WARN("Unable to issue read operation for dev %d. RC: %zd, ERROR: %s",
				      dev_id, rc, fi_strerror(-rc));
			ret = -ENOTSUP;
			goto error;
		}
	} while (true);
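
	/* Only count the request as inflight once fi_read has been accepted
	 * by the provider. */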
	(r_comm->num_inflight_reqs)++;

	/* Set request size */
	req->size = r_comm->flush_buff.size;

	*base_req = &req->base;

	return ret;

error:
	if (req)
		sendrecv_recv_comm_free_req(r_comm, dev_id, req, false);
exit:
	*base_req = NULL;
	return ret;
}
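
/*
 * Illustrative sketch (not part of this file): the flush request returned in
 * *base_req completes asynchronously, so a caller is expected to poll it like
 * any other plugin request. Assuming the base request exposes a test() hook,
 * as in the plugin's request abstraction, polling might look like:
 *
 *   int done = 0, size = 0;
 *   while (!done) {
 *       if (flush_req->test(flush_req, &done, &size) != 0)
 *           break;  // hypothetical error handling
 *   }
 */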