static int sendrecv_endpoint_listen()

in src/nccl_ofi_sendrecv.cpp [1744:1820]


/*
 * @brief	Create a listen communicator on a sendrecv endpoint and fill in
 *		the connection handle that identifies it.
 *
 * The handle is zeroed and then populated with the endpoint's local address
 * (MAX_EP_ADDR bytes) and a communicator ID derived from a freshly allocated
 * tag. The local address is also inserted into the endpoint's address vector
 * so that it can be used to issue local read operations.
 *
 * @param	base_ep
 *		Endpoint to listen on; treated as a nccl_net_ofi_sendrecv_ep_t.
 * @param	handle
 *		Output connection handle. Overwritten even when the call fails.
 * @param	listen_comm
 *		Output listen communicator. Written only on success.
 *
 * @return	0 on success
 *		-EINVAL if the endpoint has no device, the local address
 *		cannot be retrieved, or the address vector insert does not
 *		insert exactly one address
 *		-ENOSPC if the endpoint's tag would reach the device's max_tag
 *		-ENOMEM if the listen communicator cannot be allocated
 *
 * Note: the endpoint's tag counter is advanced before the later steps run
 * and is not rolled back if one of them fails.
 */
static int sendrecv_endpoint_listen(nccl_net_ofi_ep_t *base_ep,
				    nccl_net_ofi_conn_handle_t *handle,
				    nccl_net_ofi_listen_comm_t **listen_comm)
{
	int ret = 0;
	char *local_ep_name = NULL;
	fi_addr_t local_ep_addr;
	nccl_net_ofi_sendrecv_listen_comm_t *l_comm = NULL;
	uint64_t tag;
	int dev_id = 0;
	int num_addrs;
	nccl_net_ofi_sendrecv_ep_t *ep =
		(nccl_net_ofi_sendrecv_ep_t *)base_ep;

	/* Retrieve and validate device */
	nccl_net_ofi_sendrecv_device_t *device = sendrecv_endpoint_get_device(ep);
	if (OFI_UNLIKELY(device == NULL)) {
		NCCL_OFI_WARN("Invalid device provided");
		return -EINVAL;
	}

	dev_id = device->base.dev_id;

	/* Zero-out the handle */
	memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t));

	/* Increase tag ID */
	if (ep->tag + 1 >=
	    device->max_tag) {
		NCCL_OFI_WARN("Cannot open more connection for device ID %d."
			      " Maximum is %ld",
			      dev_id, device->max_tag);
		return -ENOSPC;
	}
	tag = ++ep->tag;

	/*
	 * Build handle.
	 *
	 * NOTE(review): this assumes sendrecv_get_local_address() returns a
	 * heap buffer owned by the caller, which is why every path below goes
	 * through the exit label that frees it -- confirm against the helper's
	 * definition.
	 */
	local_ep_name = sendrecv_get_local_address(ep->ofi_ep);
	if (local_ep_name == NULL) {
		return -EINVAL;
	}

	memcpy(handle->ep_name, local_ep_name, MAX_EP_ADDR);
	/* Only the low 32 bits of the tag are carried in the handle */
	handle->comm_id = (uint32_t)tag;

	/* Insert local EP address to AV. This will be used to issue local read operations */
	num_addrs = fi_av_insert(ep->av, (void *)local_ep_name, 1,
				 &local_ep_addr, 0, NULL);

	/* Only 1 address should be inserted into the AV */
	if (OFI_UNLIKELY(num_addrs != 1)) {
		NCCL_OFI_WARN("Unable to insert local endpoint address into address vector for device %d.", dev_id);
		ret = -EINVAL;
		goto exit;
	}

	/* Build listen_comm */
	l_comm = (nccl_net_ofi_sendrecv_listen_comm_t *)calloc(
		1,
		sizeof(nccl_net_ofi_sendrecv_listen_comm_t));
	if (OFI_UNLIKELY(l_comm == NULL)) {
		NCCL_OFI_WARN("Couldn't allocate listen_comm for dev %d", dev_id);
		ret = -ENOMEM;
		goto exit;
	}

	/* Initialize listen communicator */
	l_comm->base.base.type = NCCL_NET_OFI_LISTEN_COMM;
	l_comm->base.base.ep = base_ep;
	l_comm->base.base.dev_id = dev_id;
	l_comm->base.accept = sendrecv_listen_comm_accept;
	l_comm->base.close = sendrecv_listen_comm_close;
	l_comm->tag = tag;
	l_comm->local_ep = ep->ofi_ep;
	l_comm->accepted = false;
	l_comm->local_ep_addr = local_ep_addr;

	*listen_comm = (nccl_net_ofi_listen_comm_t *)l_comm;

exit:
	/*
	 * The address bytes have already been copied into the handle and
	 * handed to fi_av_insert(), so the temporary buffer is no longer
	 * needed on either the success or the failure path.
	 */
	free(local_ep_name);
	return ret;
}