in src/nccl_ofi_rdma.cpp [3259:3362]
/*
 * Allocate a send-control request, fill in its control message and attach
 * it to a receive request.
 *
 * The control message carries the address, length and per-rail memory
 * registration keys of the receive buffer, together with the remote
 * communicator ID and the message sequence number.
 *
 * On failure, the partially constructed send-control request is released
 * again; it is only linked into `recv_req` once it has been fully set up.
 *
 * r_comm:   Receive communicator providing the request and control
 *           buffer freelists.
 * device:   Currently unused by this function.
 * dev_id:   Device ID stored in the new request.
 * msg_seq_num: Message sequence number stored in the request and in the
 *           control message.
 * buff:     Receive buffer whose address is advertised.
 * size:     Length of the receive buffer that is advertised.
 * buff_mr_handle: Memory registration handle of the receive buffer; one
 *           key per rail is read from it.
 * recv_req: Receive request the new send-control request is attached to.
 * recv_completion_optional: Selects the NO_COMPLETION control message
 *           type instead of the regular one.
 *
 * Returns 0 on success.
 *         -EINVAL if no request could be allocated or the scheduler
 *          returned no schedule or a schedule that is not exactly one
 *          transfer.
 *         -ENOMEM if no control buffer could be allocated.
 *         -ENOTSUP if virt_addr_mr is disabled, or a key does not fit
 *          into the short key format.
 *         -ENOENT if a rail has no memory registration key available.
 */
static inline int insert_send_ctrl_req(
				nccl_net_ofi_rdma_recv_comm_t *r_comm,
				nccl_net_ofi_rdma_device_t *device,
				int dev_id, uint16_t msg_seq_num, void *buff,
				size_t size,
				nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
				nccl_net_ofi_rdma_req_t *recv_req,
				bool recv_completion_optional)
{
	/*
	 * All function-scope locals are declared up front so that the
	 * `goto error` jumps below never cross an initialization.
	 */
	int ret = 0;
	rdma_req_send_ctrl_data_t *send_ctrl_data = NULL;
	nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = NULL;
	rdma_req_recv_data_t *recv_data = NULL;
	nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
	nccl_net_ofi_rdma_domain_t *domain = rdma_endpoint_get_domain(ep);
	assert(domain != NULL);
	nccl_net_ofi_scheduler_t *scheduler = domain->scheduler;
	nccl_net_ofi_rdma_req_t *send_ctrl_req = allocate_req(r_comm->nccl_ofi_reqs_fl);
	if (OFI_UNLIKELY(send_ctrl_req == NULL)) {
		NCCL_OFI_WARN("Unable to get NCCL OFI send control request for device %d",
			      dev_id);
		return -EINVAL;
	}

	send_ctrl_req->comm = &r_comm->base.base;
	send_ctrl_req->dev_id = dev_id;
	send_ctrl_req->type = NCCL_OFI_RDMA_SEND_CTRL;
	send_ctrl_req->free = free_send_ctrl_req;
	send_ctrl_req->msg_seq_num = msg_seq_num;

	/*
	 * Put the request data into a well-defined state before anything
	 * can fail, so that the error path can hand the request to its
	 * free handler regardless of how far setup got.
	 */
	send_ctrl_data = get_send_ctrl_data(send_ctrl_req);
	send_ctrl_data->recv_req = recv_req;
	send_ctrl_data->ctrl_schedule = NULL;
	send_ctrl_data->ctrl_fl_elem = NULL;

	/* A schedule is only needed when there is a choice of control rails */
	if (ep->num_control_rails > 1) {
		size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys);
		send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler, ctrl_msg_len, ep->num_control_rails);

		if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) {
			ret = -EINVAL;
			goto error;
		} else if (OFI_UNLIKELY(send_ctrl_data->ctrl_schedule->num_xfer_infos != 1)) {
			/* Report the control message length, not the receive buffer size */
			NCCL_OFI_WARN(
				"Invalid schedule for outgoing control message (%zu bytes). Expected one rail, but got "
				"%zu",
				ctrl_msg_len,
				send_ctrl_data->ctrl_schedule->num_xfer_infos);
			ret = -EINVAL;
			goto error;
		}
	}

	/*
	 * Allocate RDMA control buffer which transfers the RDMA write buffer
	 * information to sender.
	 */
	send_ctrl_data->ctrl_fl_elem = nccl_ofi_freelist_entry_alloc
					(r_comm->ctrl_buff_fl);
	if (send_ctrl_data->ctrl_fl_elem == NULL) {
		NCCL_OFI_WARN("Call to nccl_ofi_freelist_entry_alloc failed");
		ret = -ENOMEM;
		goto error;
	}

	if (!virt_addr_mr) {
		/*
		 * TODO: Here, we have to compute the offset of
		 * NCCL's buffer relative to the registration.
		 */
		NCCL_OFI_WARN("Non-virt_addr_mr mode is not supported yet!");
		ret = -ENOTSUP;
		goto error;
	}

	ctrl_msg = rdma_send_ctrl_get_msg(send_ctrl_data);

	/* If early completion is turned on, CTRL msg type will be NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION to influence send() behavior */
	ctrl_msg->type = recv_completion_optional ? NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION : NCCL_OFI_RDMA_MSG_CTRL;
	ctrl_msg->remote_comm_id = r_comm->remote_comm_id;
	ctrl_msg->msg_seq_num = msg_seq_num;
	ctrl_msg->buff_addr = (uint64_t)buff;
	ctrl_msg->buff_len = size;

	/* Advertise one memory registration key per rail */
	for (uint16_t rail_id = 0; rail_id < r_comm->num_rails; rail_id++) {
		uint64_t rkey = fi_mr_key(buff_mr_handle->mr[rail_id]);

		if (rkey == FI_KEY_NOTAVAIL) {
			NCCL_OFI_WARN("RDMA write buffers should be pre-registered");
			ret = -ENOENT;
			goto error;
		}

		if (ep->use_long_rkeys) {
			ctrl_msg->long_buff_mr_key[rail_id] = rkey;
		} else {
			/* The short key field cannot represent larger keys */
			if (rkey > (1ULL << (NCCL_NET_OFI_CTRL_MSG_SHORT_KEY_SIZE * 8)) - 1) {
				NCCL_OFI_WARN("Libfabric returned rkey larger than declared rkey size: %" PRIu64,
					      rkey);
				ret = -ENOTSUP;
				goto error;
			}
			ctrl_msg->short_buff_mr_key[rail_id] = rkey;
		}
	}

	/* Publish the request only after it has been fully constructed */
	recv_data = get_recv_data(recv_req);
	recv_data->send_ctrl_req = send_ctrl_req;

	return 0;

 error:
	/*
	 * The request has not been linked into recv_req yet, so nobody else
	 * can release it; without this it would be lost on failure.
	 *
	 * NOTE(review): assumes the request's free handler takes
	 * (req, dec_inflight_reqs) and releases ctrl_schedule and
	 * ctrl_fl_elem when they are non-NULL -- confirm against
	 * free_send_ctrl_req.
	 */
	(void)send_ctrl_req->free(send_ctrl_req, false);
	return ret;
}