in src/nccl_ofi_rdma.cpp [5857:6069]
/*
 * Post a send operation on an RDMA send communicator.
 *
 * Flow: validate the communicator and inflight-request limit, drain any
 * pending CQ work, then look up the message sequence number in the message
 * buffer to learn whether the receiver's RDMA control message has already
 * arrived. Allocate a send request, and either post the RDMA write
 * immediately (control message seen, or eager-eligible) or leave the request
 * parked until the control message arrives.
 *
 * @param send_comm  send communicator (must be active)
 * @param data       user buffer to send
 * @param size       number of bytes to send (0 is handled specially, see below)
 * @param tag        NCCL tag (currently unused here; see TODO below)
 * @param mhandle    memory registration handle covering @data
 * @param base_req   out: request handle for NCCL to test, or NULL if the
 *                   network is busy and NCCL should retry
 * @return 0 on success (including the "busy, retry" case), negative errno
 *         on failure
 */
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **base_req)
{
int ret = 0;
nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm;
nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle;
nccl_net_ofi_rdma_ep_t *ep = NULL;
nccl_net_ofi_rdma_domain_t *domain = NULL;
nccl_net_ofi_rdma_req_t *req = NULL;
uint16_t msg_seq_num = s_comm->next_msg_seq_num;
/* Ensure we only poll the control-rail CQs once per call (see retry loop) */
bool polled_cq = false;
/* True once we know the receiver's RDMA control message has arrived */
bool have_ctrl = false;
bool eager = false;
int dev_id = 0;
assert(s_comm != NULL);
if (s_comm->comm_active == false) {
NCCL_OFI_WARN("Called isend on inactive communicator");
ret = -EINVAL;
goto error;
}
/* Support only NCCL_OFI_MAX_REQUESTS inflight requests. */
if (OFI_UNLIKELY(s_comm->num_inflight_reqs == NCCL_OFI_MAX_SEND_REQUESTS)) {
ret = -EINVAL;
NCCL_OFI_WARN("Can not support more than %d inflight requests",
NCCL_OFI_MAX_SEND_REQUESTS);
goto error;
}
dev_id = s_comm->base.base.dev_id;
ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
assert(ep != NULL);
domain = rdma_endpoint_get_domain(ep);
assert(domain != NULL);
/* Make progress on previously queued requests before adding new work */
ret = process_cq_if_pending(ep);
if (ret == -EAGAIN) {
/* Network is still busy. Return NULL to NCCL. */
*base_req = NULL;
ret = 0;
goto error;
}
if (ret != 0) {
goto error;
}
/*
 * TODO: Use NCCL provided tags when using grouped receives aka
 * props->maxRecvs > 1.
 */
have_ctrl = false;
msg_seq_num = s_comm->next_msg_seq_num;
void *elem;
nccl_ofi_msgbuff_elemtype_t type;
nccl_ofi_msgbuff_status_t msg_stat;
nccl_ofi_msgbuff_result_t mb_res;
retry:
/* Retrive entry from message buffer for msg_seq_num index */
mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, &elem,
&type, &msg_stat);
if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) {
if (OFI_LIKELY(type == NCCL_OFI_MSGBUFF_BUFF)) {
/*
 * Received RDMA control message from receiver so
 * allocate request and initiate RDMA write
 */
have_ctrl = true;
} else if (type == NCCL_OFI_MSGBUFF_REQ) {
/* Shouldn't happen: we already have a req in the message buffer */
NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num);
ret = -EINVAL;
goto error;
} else {
NCCL_OFI_WARN("Unexpected type of buffer retrieved from message buffer: %d",
type);
ret = -EINVAL;
goto error;
}
} else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) &&
(msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) {
/*
 * We haven't encountered this message sequence number.
 * Allocate a request so that we are able to send RDMA write
 * as soon as we receive the RDMA control message.
 */
have_ctrl = false;
} else {
NCCL_OFI_WARN("Message %hu has invalid status. res = %d and stat = %d",
msg_seq_num, mb_res, msg_stat);
ret = -EINVAL;
goto error;
}
/* look for control messages and then retry the message search
   to avoid unnecessary polling / queueing. */
if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) {
for (uint16_t rail_id = 0; rail_id != ep->num_control_rails; ++rail_id) {
nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_control_rail(ep, rail_id);
ret = ofi_process_cq_rail(ep, rail);
if (OFI_UNLIKELY(ret != 0)) {
goto error;
}
}
/* polled_cq guarantees the retry loop terminates after one extra pass */
polled_cq = true;
goto retry;
}
/* NCCL versions prior to 2.24 require special handling for 0 byte
 * messages when using user buffer registration. NCCL passes the base
 * pointer from the user buffer, but passes the registration from the
 * channel buffer, to avoid an MR cache lookup. This is fine with
 * InfiniBand, where the spec says the SGE is not used for a 0 byte
 * message, but is a problem for EFA, which validates the pointer / MR
 * even for a 0 byte transfer.
 *
 * To handle this case, we use the flush buffer (note we still move 0
 * bytes of data, we just need a valid SGE) instead of the provided base
 * pointer and MR
 */
if (size == 0) {
data = domain->flush_buff.host_buffer;
mr_handle = domain->flush_buff.mr_handle;
}
/* Determine if this should be sent eagerly. Eager is only used when the
 * control message has not arrived yet, the payload fits the eager size
 * threshold, and no RDMA writes are currently in flight on this comm. */
eager = false;
if (!have_ctrl && (ssize_t)size <= ep->eager_send_size && s_comm->num_inflight_writes == 0) {
eager = true;
}
ret = alloc_rdma_send_req(s_comm, msg_seq_num, data,
size, mr_handle, eager, &req);
if (OFI_UNLIKELY(ret != 0)) {
goto error;
}
if (have_ctrl) {
/*
 * For already received RDMA control message, populate
 * the RDMA write metadata from the rx buffer
 */
nccl_net_ofi_rdma_req_t *rx_buff_req = (nccl_net_ofi_rdma_req_t *)elem;
ret = update_send_data_from_remote(s_comm, rx_buff_req, req);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Failed to copy ctrl data");
goto error;
}
/* Post if needed */
ret = check_post_rx_buff_req(rx_buff_req);
if (OFI_UNLIKELY(ret != 0)) {
goto error;
}
}
/* On failure this may free req and set it to NULL; the shared cleanup
 * below handles both outcomes. */
ret = insert_rdma_send_req_into_msgbuff(s_comm, dev_id, have_ctrl, &req);
if (OFI_UNLIKELY(ret != 0 || req == NULL)) {
goto free_req;
}
/*
 * At this point, we've successfully inserted a new request,
 * so update the num inflight
 */
(s_comm->num_inflight_reqs)++;
if (!eager) {
/* Non-eager sends will issue an RDMA write; track it so the eager
 * eligibility check above stays accurate. */
(s_comm->num_inflight_writes)++;
}
NCCL_OFI_TRACE_SEND(req->dev_id, size, s_comm, msg_seq_num, req, base_req);
/* Try posting RDMA write for received RDMA control messages */
if (have_ctrl || eager) {
ret = send_progress(req);
if (ret == -FI_EAGAIN) {
/* Add to pending reqs queue */
nccl_net_ofi_mutex_lock(&ep->pending_reqs_lock);
ep->pending_reqs_queue->push_back(req);
nccl_net_ofi_mutex_unlock(&ep->pending_reqs_lock);
/* Queued for later progress; report success to NCCL */
ret = 0;
NCCL_OFI_TRACE_PENDING_INSERT(req);
} else if (OFI_UNLIKELY(ret != 0)) {
/* TODO: Remove req from message buffer */
/* NOTE(review): the original error code from send_progress is
 * discarded and replaced with -ENOTSUP here; also the inflight
 * counters incremented above are not rolled back on this path —
 * presumably intentional since the req is still in the message
 * buffer (see TODO), but worth confirming. */
ret = -ENOTSUP;
goto error;
}
}
/* Return request to NCCL */
*base_req = &req->base;
/* Increment next_msg_seq_num for next call */
s_comm->next_msg_seq_num = (s_comm->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK;
goto exit;
/* free_req and error share the same cleanup: release the request (if any)
 * and hand NULL back to NCCL. ret carries the error (or 0 for the
 * busy/retry case reached via the -EAGAIN branch above). */
free_req:
error:
if (req)
req->free(req, false);
*base_req = NULL;
exit:
return ret;
}