ssize_t psmx2_atomic_readwritev_generic()

in prov/psm2/src/psmx2_atomic.c [1178:1349]


/*
 * Issue a fetching atomic operation whose operand and result buffers are
 * described by fi_ioc vectors.
 *
 * Flow:
 *   - With FI_TRIGGER set, the call is forwarded unchanged to
 *     psmx2_trigger_queue_atomic_readwritev() and its result is returned.
 *   - Trailing zero-count entries of iov and resultv are ignored.
 *   - If the translated destination has the same epid as the local TX
 *     context, the operation is performed through psmx2_atomic_self(),
 *     using temporary contiguous buffers when more than one vector entry
 *     is involved.
 *   - Otherwise an AM request is allocated and a short active message is
 *     sent to PSMX2_AM_ATOMIC_HANDLER.
 *
 * @param ep            endpoint the operation is issued on
 * @param iov           operand vector; may be NULL only for FI_ATOMIC_READ
 * @param desc          operand descriptors, or NULL
 * @param count         number of entries in iov
 * @param resultv       result vector (must not be NULL)
 * @param result_desc   result descriptors, or NULL
 * @param result_count  number of entries in resultv (must be non-zero)
 * @param dest_addr     destination address, translated through the
 *                      endpoint's address vector
 * @param addr          target address, passed through to the target side
 * @param key           target memory key, passed through to the target side
 * @param datatype      element datatype
 * @param op            atomic operation
 * @param context       user context stored with the request
 * @param flags         operation flags (FI_TRIGGER, FI_INJECT,
 *                      PSMX2_NO_COMPLETION and FI_COMPLETION are examined)
 *
 * @return 0 once the remote request has been handed to
 *         psm2_am_request_short(); the value returned by
 *         psmx2_atomic_self() for a self-targeted operation;
 *         -FI_ENOMEM if an allocation fails; -FI_EMSGSIZE if the operand
 *         length exceeds the TX context's psm2_am_param.max_request_short.
 */
ssize_t psmx2_atomic_readwritev_generic(struct fid_ep *ep,
					const struct fi_ioc *iov,
					void **desc, size_t count,
					struct fi_ioc *resultv,
					void **result_desc,
					size_t result_count,
					fi_addr_t dest_addr,
					uint64_t addr, uint64_t key,
					enum fi_datatype datatype,
					enum fi_op op, void *context,
					uint64_t flags)
{
	struct psmx2_fid_ep *ep_priv;
	struct psmx2_fid_av *av;
	struct psmx2_am_request *req;
	psm2_amarg_t args[8];
	psm2_epaddr_t psm2_epaddr;
	psm2_epid_t psm2_epid;
	int am_flags = PSM2_AM_FLAG_ASYNC;
	int chunk_size;
	size_t len, result_len, iov_size;
	uint8_t *buf, *result;
	void *desc0, *result_desc0;
	int err;

	ep_priv = container_of(ep, struct psmx2_fid_ep, ep);

	if (flags & FI_TRIGGER)
		return psmx2_trigger_queue_atomic_readwritev(ep, iov, desc,
							     count, resultv,
							     result_desc,
							     result_count,
							     dest_addr, addr,
							     key, datatype, op,
							     context, flags);

	assert((iov && count) || op == FI_ATOMIC_READ);
	assert(resultv);
	assert(result_count);
	assert((int)datatype >= 0 && (int)datatype < FI_DATATYPE_LAST);
	assert((int)op >= 0 && (int)op < FI_ATOMIC_OP_LAST);

	/* Drop trailing empty entries so "count > 1" really means that
	 * more than one vector entry has to be gathered/scattered. */
	if (iov) {
		while (count && !iov[count-1].count)
			count--;
	}

	while (result_count && !resultv[result_count-1].count)
		result_count--;

	result_len = psmx2_ioc_size(resultv, result_count, datatype);

	if (op != FI_ATOMIC_READ) {
		buf = iov[0].addr; /* as default for count == 1 */
		len = psmx2_ioc_size(iov, count, datatype);
		desc0 = desc ? desc[0] : NULL;
	} else {
		/* A pure read carries no operand; the transfer length is
		 * that of the result vector. */
		buf = NULL;
		len = result_len;
		desc0 = NULL;
	}

	assert(result_len >= len);

	av = ep_priv->av;
	assert(av);

	psm2_epaddr = psmx2_av_translate_addr(av, ep_priv->tx, dest_addr, av->type);
	psm2_epaddr_to_epid(psm2_epaddr, &psm2_epid);

	if (psm2_epid == ep_priv->tx->psm2_epid) {
		/* Self-targeted operation: no active message is sent. */
		uint8_t *tmp_buf = NULL;	/* contiguous copy of iov */
		uint8_t *tmp_result = NULL;	/* contiguous result staging */

		if (buf && count > 1) {
			tmp_buf = malloc(len);
			if (!tmp_buf)
				return -FI_ENOMEM;

			psmx2_ioc_read(iov, count, datatype, tmp_buf, len);
			buf = tmp_buf;
			desc0 = NULL;
		}

		/*
		 * NOTE(review): the result descriptor is passed on when a
		 * temporary result buffer is used and dropped when the
		 * caller's own buffer is used, which is the reverse of the
		 * operand handling above -- confirm this is intended.
		 */
		if (result_count > 1) {
			tmp_result = malloc(len);
			if (!tmp_result) {
				free(tmp_buf);
				return -FI_ENOMEM;
			}
			result = tmp_result;
			result_desc0 = result_desc ? result_desc[0] : NULL;
		} else {
			result = resultv[0].addr;
			result_desc0 = NULL;
		}

		err = psmx2_atomic_self(PSMX2_AM_REQ_ATOMIC_READWRITE, ep_priv,
					buf, len / ofi_datatype_size(datatype),
					desc0, NULL, NULL, result, result_desc0,
					addr, key, datatype, op, context, flags);

		/* Scatter the staged result back into the caller's vector. */
		if (tmp_result) {
			psmx2_ioc_write(resultv, result_count, datatype,
					tmp_result, len);
			free(tmp_result);
		}

		free(tmp_buf);

		return err;
	}

	/* Remote target: the operand must fit in one short AM request. */
	chunk_size = ep_priv->tx->psm2_am_param.max_request_short;
	if (len > (size_t)chunk_size)
		return -FI_EMSGSIZE;

	/* A copy of resultv is kept with the request only when the result
	 * has to be scattered over several entries. */
	iov_size = result_count > 1 ? result_count * sizeof(struct fi_ioc) : 0;

	req = psmx2_am_request_alloc(ep_priv->tx);
	if (!req)
		return -FI_ENOMEM;

	/*
	 * req->tmpbuf layout: [copy of resultv: iov_size bytes]
	 *                     [contiguous operand data: len bytes, optional]
	 * The operand is copied for FI_INJECT or when it spans several
	 * iov entries; otherwise the caller's buffer is sent directly.
	 *
	 * NOTE(review): when the requested size is 0, malloc() may
	 * legitimately return NULL, which is reported as -FI_ENOMEM here.
	 */
	if (((flags & FI_INJECT) || count > 1) && op != FI_ATOMIC_READ) {
		req->tmpbuf = malloc(iov_size + len);
		if (!req->tmpbuf) {
			psmx2_am_request_free(ep_priv->tx, req);
			return -FI_ENOMEM;
		}

		buf = (uint8_t *)req->tmpbuf + iov_size;
		psmx2_ioc_read(iov, count, datatype, buf, len);
	} else {
		req->tmpbuf = malloc(iov_size);
		if (!req->tmpbuf) {
			psmx2_am_request_free(ep_priv->tx, req);
			return -FI_ENOMEM;
		}
	}

	req->ioc = req->tmpbuf;
	if (iov_size) {
		/* Multi-entry result: record the vector, no single target. */
		memcpy(req->ioc, resultv, iov_size);
		req->atomic.iov_count = result_count;
		req->atomic.result = NULL;
	} else {
		req->atomic.result = resultv[0].addr;
	}

	/* Suppress the completion event when explicitly requested, or when
	 * the endpoint uses selective completion and FI_COMPLETION is not
	 * set for this operation. */
	req->no_event = (flags & PSMX2_NO_COMPLETION) ||
			(ep_priv->send_selective_completion && !(flags & FI_COMPLETION));

	req->op = PSMX2_AM_REQ_ATOMIC_READWRITE;
	req->atomic.buf = (void *)buf;
	req->atomic.len = len;
	req->atomic.addr = addr;
	req->atomic.key = key;
	req->atomic.context = context;
	req->atomic.datatype = datatype;
	req->ep = ep_priv;
	if (op == FI_ATOMIC_READ)
		req->cq_flags = FI_READ | FI_ATOMIC;
	else
		req->cq_flags = FI_WRITE | FI_ATOMIC;

	args[0].u32w0 = PSMX2_AM_REQ_ATOMIC_READWRITE;
	args[0].u32w1 = len / ofi_datatype_size(datatype);	/* element count */
	args[1].u64 = (uint64_t)(uintptr_t)req;
	args[2].u64 = addr;
	args[3].u64 = key;
	args[4].u32w0 = datatype;
	args[4].u32w1 = op;

	/*
	 * NOTE(review): the return value of psm2_am_request_short() is not
	 * checked; if it can fail, req and req->tmpbuf are presumably never
	 * released -- verify against the reply handler.
	 */
	psm2_am_request_short(psm2_epaddr,
			      PSMX2_AM_ATOMIC_HANDLER, args, 5,
			      (void *)buf, (buf?len:0), am_flags, NULL, NULL);
	psmx2_am_poll(ep_priv->tx);
	return 0;
}