static inline int create_send_comm()

in src/nccl_ofi_rdma.cpp [6506:6656]


static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
				   nccl_net_ofi_rdma_ep_t *ep,
				   nccl_net_ofi_rdma_send_comm_t **s_comm)
{
	int ret = 0;
	size_t comm_id = 0;
	fi_addr_t remote_addr;
	nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL;
	int num_rails = ep->num_rails;
	int num_control_rails = ep->num_control_rails;
	uint16_t rail_id = 0;
	nccl_net_ofi_ep_rail_t *first_control_rail = rdma_endpoint_get_control_rail(ep, 0);
	nccl_net_ofi_rdma_send_comm_rail_t *first_comm_control_rail;

	*s_comm = NULL;

	/* Retrieve and validate device */
	nccl_net_ofi_rdma_device_t *device = rdma_endpoint_get_device(ep);
	if (OFI_UNLIKELY(device == NULL)) {
		NCCL_OFI_WARN("Error accessing device");
		return -EINVAL;
	}
	int dev_id = device->base.dev_id;

	/* Allocate and initialize send_comm */
	ret_s_comm = calloc_rdma_send_comm(num_rails, num_control_rails);
	if (OFI_UNLIKELY(ret_s_comm == NULL)) {
		NCCL_OFI_WARN("Couldn't allocate send comm object for dev %d", dev_id);
		return -ENOMEM;
	}

	ret = nccl_net_ofi_mutex_init(&ret_s_comm->ctrl_recv_lock, NULL);
	if (ret != 0) {
		free_rdma_send_comm(ret_s_comm);
		return ret;
	}

	ret_s_comm->base.base.type = NCCL_NET_OFI_SEND_COMM;
	ret_s_comm->base.base.ep = &ep->base;
	ret_s_comm->base.base.dev_id = dev_id;
	ret_s_comm->base.regMr = reg_mr_send_comm;
	ret_s_comm->base.deregMr = dereg_mr_send_comm;
	ret_s_comm->base.send = send;
	ret_s_comm->base.close = send_close_deferred;
	ret_s_comm->base.write = rma_write;
	ret_s_comm->base.write_inline = rma_write_inline;

	ret_s_comm->comm_active = true;
	ret_s_comm->next_msg_seq_num = 0;

	ret_s_comm->received_close_message = false;
	ret_s_comm->n_ctrl_received = 0;
	ret_s_comm->n_ctrl_expected = 0;

	/* Store communicator ID from handle in communicator */
	if (OFI_UNLIKELY(handle->comm_id >= device->num_comm_ids)) {
		NCCL_OFI_WARN("Received an invalid communicator ID %" PRIu32 " for device %d", handle->comm_id,
			      dev_id);
		ret = -EINVAL;
		goto error;
	}
	ret_s_comm->remote_comm_id = handle->comm_id;

	/* Allocate send communicator ID */
	comm_id = device->comm_idpool->allocate_id();
	if (OFI_UNLIKELY(comm_id == FI_KEY_NOTAVAIL)) {
		ret_s_comm->local_comm_id = COMM_ID_INVALID;
		ret = -ENOMEM;
		goto error;
	}
	ret_s_comm->local_comm_id = (uint32_t)comm_id;

	/* Add ourselves to ep's lookup array */
	rdma_device_set_comm(device, ret_s_comm->local_comm_id, &ret_s_comm->base.base);

	/* Allocate communicator rails array */
	ret_s_comm->num_rails = num_rails;
	ret_s_comm->num_control_rails = num_control_rails;

	/* Insert remote name into AV of first rail */
	ret = fi_av_insert(first_control_rail->av,
			   (void *)handle->ep_name, 1,
			   &remote_addr, 0, NULL);
	if (OFI_UNLIKELY(ret != 1)) {
		NCCL_OFI_WARN("Unable to insert remote address into address vector for device %d. RC: %d",
			      dev_id, ret);
		ret = -EINVAL;
		goto error;
	}

	/* Store remote address of first rail in communicator */
	first_comm_control_rail = &ret_s_comm->control_rails[0];
	first_comm_control_rail->remote_addr = remote_addr;

	/* Store local libfabric endpoint of control rail */
	first_comm_control_rail->local_ep = first_control_rail->ofi_ep;
	ret_s_comm->num_init_control_rails = 1;

	/* Allocate request free list */
	ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16,
				     NCCL_OFI_MAX_SEND_REQUESTS,
				     rdma_fl_req_entry_init, rdma_fl_req_entry_fini,
				     &ret_s_comm->nccl_ofi_reqs_fl);
	if (OFI_UNLIKELY(ret != 0)) {
		NCCL_OFI_WARN("Could not allocate NCCL OFI request free list for dev %d rail %d",
			      dev_id, rail_id);
		goto error;
	}

	/* Allocate connect message, will be returned after send completion */
	ret_s_comm->conn_msg = nccl_ofi_freelist_entry_alloc(ep->conn_msg_fl);
	if (ret_s_comm->conn_msg == NULL) {
		NCCL_OFI_WARN("Failed to allocate conn_msg buffer");
		return -ENOMEM;
	}

	prepare_send_connect_message(ep, dev_id, ret_s_comm->local_comm_id, ret_s_comm->remote_comm_id, handle,
				     (nccl_ofi_rdma_connection_info_t *)ret_s_comm->conn_msg->ptr);

	/* Allocate message buffer */
	ret_s_comm->msgbuff = nccl_ofi_msgbuff_init(NCCL_OFI_RDMA_MSGBUFF_SIZE, NCCL_OFI_RDMA_SEQ_BITS);
	if (!ret_s_comm->msgbuff) {
		NCCL_OFI_WARN("Failed to allocate and initialize message buffer");
		ret = -ENOMEM;
		goto error;
	}

#if HAVE_NVTX_TRACING && NCCL_OFI_NVTX_TRACE_PER_COMM
	for (int i = 0; i < NCCL_OFI_N_NVTX_DOMAIN_PER_COMM; ++i)
	{
		/* Create nvtx domain */
		char name[64];
		snprintf(name, 64, "aws-ofi-nccl s_comm %p_%d", ret_s_comm, i);
		ret_s_comm->nvtx_domain[i] = nvtxDomainCreateA(name);
	}
#endif
	*s_comm = ret_s_comm;
	return ret;


 error:
	if (ret_s_comm) {
		if (COMM_ID_INVALID != ret_s_comm->local_comm_id) {
			device->comm_idpool->free_id(ret_s_comm->local_comm_id);
		}
		nccl_net_ofi_mutex_destroy(&ret_s_comm->ctrl_recv_lock);
		free_rdma_send_comm(ret_s_comm);
	}

	return ret;
}