/*
 * insert_send_ctrl_req() — excerpt from src/nccl_ofi_rdma.cpp [3259:3362]
 */

/**
 * Allocate and populate a send-control request that transfers the receiver's
 * RDMA write-buffer information (address, length, rkeys) to the sender, and
 * attach it to @a recv_req.
 *
 * @param r_comm                   Receive communicator the control message belongs to.
 * @param device                   RDMA device (currently unused in this function body).
 * @param dev_id                   Device index, used for diagnostics.
 * @param msg_seq_num              Sequence number tying the control message to its recv.
 * @param buff                     Receiver's destination buffer address.
 * @param size                     Length of the destination buffer in bytes.
 * @param buff_mr_handle           MR handle holding per-rail registrations of @a buff.
 * @param recv_req                 Receive request that will reference the new control request.
 * @param recv_completion_optional When true, mark the control message so the sender
 *                                 skips the receive-completion round trip.
 *
 * @return 0 on success; negative errno-style code on failure.
 *
 * NOTE(review): on the error paths below, the request obtained from
 * allocate_req() (and, later, the freelist entry) is not returned to its
 * freelist before returning — presumably reclaimed when the communicator is
 * torn down; confirm against the freelist ownership rules.
 */
static inline int insert_send_ctrl_req(
				nccl_net_ofi_rdma_recv_comm_t *r_comm,
				nccl_net_ofi_rdma_device_t *device,
				int dev_id, uint16_t msg_seq_num, void *buff,
				size_t size,
				nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
				nccl_net_ofi_rdma_req_t *recv_req,
				bool recv_completion_optional)
{
	nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
	nccl_net_ofi_rdma_domain_t *domain = rdma_endpoint_get_domain(ep);
	assert(domain != NULL);
	nccl_net_ofi_scheduler_t *scheduler = domain->scheduler;
	nccl_net_ofi_rdma_req_t *send_ctrl_req = allocate_req(r_comm->nccl_ofi_reqs_fl);
	if (OFI_UNLIKELY(send_ctrl_req == NULL)) {
		NCCL_OFI_WARN("Unable to get NCCL OFI send control request for device %d",
						dev_id);
		return -EINVAL;
	}

	send_ctrl_req->comm = &r_comm->base.base;
	send_ctrl_req->dev_id = dev_id;
	send_ctrl_req->type = NCCL_OFI_RDMA_SEND_CTRL;
	send_ctrl_req->free = free_send_ctrl_req;
	send_ctrl_req->msg_seq_num = msg_seq_num;

	rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req);

	if (ep->num_control_rails > 1) {
		/* With multiple control rails, ask the scheduler which rail to use;
		 * the control message must still land on exactly one rail. */
		size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys);
		send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler, ctrl_msg_len, ep->num_control_rails);

		if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) {
			return -EINVAL;
		} else if (OFI_UNLIKELY(send_ctrl_data->ctrl_schedule->num_xfer_infos != 1)) {
			/* Report the control-message length (what was scheduled),
			 * not the user buffer size. */
			NCCL_OFI_WARN(
				"Invalid schedule for outgoing control message (%zu bytes). Expected one rail, but got "
				"%zu",
				ctrl_msg_len,
				send_ctrl_data->ctrl_schedule->num_xfer_infos);
			return -EINVAL;
		}
	} else {
		send_ctrl_data->ctrl_schedule = NULL;
	}

	send_ctrl_data->recv_req = recv_req;
	send_ctrl_data->ctrl_fl_elem = NULL;

	/*
	 * Allocate RDMA control buffer which transfers the RDMA write buffer
	 * information to sender.
	 */
	send_ctrl_data->ctrl_fl_elem = nccl_ofi_freelist_entry_alloc
					(r_comm->ctrl_buff_fl);
	if (send_ctrl_data->ctrl_fl_elem == NULL) {
		NCCL_OFI_WARN("Call to nccl_ofi_freelist_entry_alloc failed");
		return -ENOMEM;
	}

	if (!virt_addr_mr) {
		/*
		 * TODO: Here, we have to compute the offset of
		 * NCCL's buffer relative to the registration.
		 */
		NCCL_OFI_WARN("virt_addr_mr mode is not supported yet!");
		return -ENOTSUP;
	}

	nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = rdma_send_ctrl_get_msg(send_ctrl_data);

	/* If early completion is turned on, CTRL msg type will be NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION to influence send() behavior */
	ctrl_msg->type = recv_completion_optional ? NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION : NCCL_OFI_RDMA_MSG_CTRL;
	ctrl_msg->remote_comm_id = r_comm->remote_comm_id;
	ctrl_msg->msg_seq_num = msg_seq_num;
	ctrl_msg->buff_addr = (uint64_t)buff;
	ctrl_msg->buff_len = size;

	/* Collect one rkey per rail so the sender can RDMA-write on any rail. */
	uint16_t rail_id = 0;
	for (; rail_id < r_comm->num_rails; rail_id++) {
		uint64_t rkey = fi_mr_key(buff_mr_handle->mr[rail_id]);

		if (rkey == FI_KEY_NOTAVAIL) {
			NCCL_OFI_WARN("RDMA write buffers should be pre-registered");
			return -ENOENT;
		}

		if (ep->use_long_rkeys) {
			ctrl_msg->long_buff_mr_key[rail_id] = rkey;
		} else {
			/* Short-key wire format: rkey must fit in
			 * NCCL_NET_OFI_CTRL_MSG_SHORT_KEY_SIZE bytes. */
			if (rkey > (1ULL << (NCCL_NET_OFI_CTRL_MSG_SHORT_KEY_SIZE * 8)) - 1) {
				NCCL_OFI_WARN("Libfabric returned rkey larger than declared rkey size: %" PRIu64,
					      rkey);
				return -ENOTSUP;
			}
			ctrl_msg->short_buff_mr_key[rail_id] = rkey;
		}
	}

	/* Link the control request into the recv request so its completion can
	 * be tracked alongside the receive. */
	rdma_req_recv_data_t *recv_data = get_recv_data(recv_req);
	recv_data->send_ctrl_req = send_ctrl_req;

	return 0;
}