ssize_t psmx2_atomic_compwritev_generic()

in prov/psm2/src/psmx2_atomic.c [1559:1753]


/*
 * Vectored compare-and-write atomic operation.
 *
 * ep             endpoint the operation is issued on
 * iov/count      operand element vector (desc: optional per-entry descriptors)
 * comparev/...   compare-value element vector; must cover at least as many
 *                bytes as iov (asserted)
 * resultv/...    element vector receiving the fetched values; must cover at
 *                least as many bytes as iov (asserted)
 * dest_addr      destination, translated through the endpoint's AV
 * addr, key      target location and memory key, forwarded unchanged
 * datatype, op   atomic datatype and operation
 * context, flags user context and operation flags
 *
 * Behaviour:
 *  - With FI_TRIGGER set the call is forwarded unchanged to
 *    psmx2_trigger_queue_atomic_compwritev().
 *  - Trailing zero-count entries are trimmed from all three vectors.
 *  - If the destination epid equals the local TX context's epid the
 *    operation runs through psmx2_atomic_self(); vectors with more than one
 *    entry are gathered into / scattered from temporary contiguous buffers.
 *  - Otherwise operand and compare data are sent as one contiguous payload
 *    (operand bytes followed by compare bytes) in a single
 *    psm2_am_request_short(); this requires 2 * len to fit in
 *    max_request_short.
 *
 * Returns the result of the trigger queue call or of psmx2_atomic_self() on
 * those paths, -FI_ENOMEM on allocation failure, -FI_EMSGSIZE if the payload
 * does not fit in a short AM request, and 0 once the AM request was issued.
 */
ssize_t psmx2_atomic_compwritev_generic(struct fid_ep *ep,
					const struct fi_ioc *iov,
					void **desc, size_t count,
					const struct fi_ioc *comparev,
					void **compare_desc,
					size_t compare_count,
					struct fi_ioc *resultv,
					void **result_desc,
					size_t result_count,
					fi_addr_t dest_addr,
					uint64_t addr, uint64_t key,
					enum fi_datatype datatype,
					enum fi_op op, void *context,
					uint64_t flags)
{
	struct psmx2_fid_ep *ep_priv;
	struct psmx2_fid_av *av;
	struct psmx2_am_request *req;
	psm2_amarg_t args[8];
	psm2_epaddr_t psm2_epaddr;
	psm2_epid_t psm2_epid;
	int am_flags = PSM2_AM_FLAG_ASYNC;
	int chunk_size;
	int pack;
	size_t len, iov_size, tmp_size, alloc_len;
	uint8_t *buf, *compare, *result;
	/* Staging buffers owned by this function (loopback path only). */
	uint8_t *buf_tmp = NULL, *compare_tmp = NULL, *result_tmp = NULL;
	void *desc0, *compare_desc0, *result_desc0;
	int err;

	ep_priv = container_of(ep, struct psmx2_fid_ep, ep);

	if (flags & FI_TRIGGER)
		return psmx2_trigger_queue_atomic_compwritev(ep, iov, desc,
							     count, comparev,
							     compare_desc,
							     compare_count,
							     resultv,
							     result_desc,
							     result_count,
							     dest_addr, addr,
							     key, datatype, op,
							     context, flags);

	assert(iov);
	assert(count);
	assert(comparev);
	assert(compare_count);
	assert(resultv);
	assert(result_count);
	assert((int)datatype >= 0 && (int)datatype < FI_DATATYPE_LAST);
	assert((int)op >= 0 && (int)op < FI_ATOMIC_OP_LAST);

	/* Ignore trailing entries that contribute no elements. */
	while (count && !iov[count-1].count)
		count--;

	while (compare_count && !comparev[compare_count-1].count)
		compare_count--;

	while (result_count && !resultv[result_count-1].count)
		result_count--;

	/* The operand vector defines the size of the operation. */
	len = psmx2_ioc_size(iov, count, datatype);

	assert(psmx2_ioc_size(comparev, compare_count, datatype) >= len);
	assert(psmx2_ioc_size(resultv, result_count, datatype) >= len);

	av = ep_priv->av;
	assert(av);

	psm2_epaddr = psmx2_av_translate_addr(av, ep_priv->tx, dest_addr, av->type);
	psm2_epaddr_to_epid(psm2_epaddr, &psm2_epid);

	if (psm2_epid == ep_priv->tx->psm2_epid) {
		/*
		 * Loopback: the target is our own TX context.
		 *
		 * malloc(0) may return NULL without being out of memory, so
		 * never ask for zero bytes; NULL then always means ENOMEM.
		 */
		alloc_len = len ? len : 1;
		err = -FI_ENOMEM;

		if (count > 1) {
			buf_tmp = malloc(alloc_len);
			if (!buf_tmp)
				goto self_done;
			psmx2_ioc_read(iov, count, datatype, buf_tmp, len);
			buf = buf_tmp;
			desc0 = NULL;
		} else {
			buf = iov[0].addr;
			desc0 = desc ? desc[0] : NULL;
		}

		if (compare_count > 1) {
			compare_tmp = malloc(alloc_len);
			if (!compare_tmp)
				goto self_done;
			psmx2_ioc_read(comparev, compare_count, datatype,
				       compare_tmp, len);
			compare = compare_tmp;
			compare_desc0 = NULL;
		} else {
			compare = comparev[0].addr;
			compare_desc0 = compare_desc ? compare_desc[0] : NULL;
		}

		if (result_count > 1) {
			/*
			 * Zero-filled: the buffer is scattered to the caller
			 * unconditionally below, so if psmx2_atomic_self()
			 * returns without producing a result no indeterminate
			 * heap bytes end up in user memory.
			 */
			result_tmp = calloc(1, alloc_len);
			if (!result_tmp)
				goto self_done;
			result = result_tmp;
			result_desc0 = NULL;
		} else {
			result = resultv[0].addr;
			result_desc0 = result_desc ? result_desc[0] : NULL;
		}

		err = psmx2_atomic_self(PSMX2_AM_REQ_ATOMIC_COMPWRITE, ep_priv,
					buf, len / ofi_datatype_size(datatype), desc0,
					compare, compare_desc0, result, result_desc0,
					addr, key, datatype, op, context, flags);

		/* Scatter the staged result back into the caller's vector. */
		if (result_tmp)
			psmx2_ioc_write(resultv, result_count, datatype,
					result_tmp, len);

self_done:
		/* free(NULL) is a no-op, so no per-buffer guards are needed. */
		free(result_tmp);
		free(compare_tmp);
		free(buf_tmp);
		return err;
	}

	/* Operand and compare data travel together in one short request. */
	chunk_size = ep_priv->tx->psm2_am_param.max_request_short;
	if (len * 2 > chunk_size)
		return -FI_EMSGSIZE;

	/* A copy of resultv is only kept when it has more than one entry. */
	iov_size = result_count > 1 ? result_count * sizeof(struct fi_ioc) : 0;

	req = psmx2_am_request_alloc(ep_priv->tx);
	if (!req)
		return -FI_ENOMEM;

	/*
	 * The payload can be sent straight from user memory only when this is
	 * not an inject, both vectors have a single entry, and the compare
	 * data directly follows the operand data. Otherwise both are packed
	 * into tmpbuf behind the (optional) resultv copy.
	 */
	pack = (flags & FI_INJECT) || count > 1 || compare_count > 1 ||
	       (uintptr_t)comparev[0].addr != (uintptr_t)iov[0].addr + len;

	tmp_size = iov_size + (pack ? len + len : 0);

	/*
	 * tmp_size is 0 in the common "contiguous operands, single result
	 * entry" case. malloc(0) is allowed to return NULL, which must not be
	 * reported as -FI_ENOMEM, so request at least one byte.
	 */
	req->tmpbuf = malloc(tmp_size ? tmp_size : 1);
	if (!req->tmpbuf) {
		psmx2_am_request_free(ep_priv->tx, req);
		return -FI_ENOMEM;
	}

	if (pack) {
		buf = (uint8_t *)req->tmpbuf + iov_size;
		psmx2_ioc_read(iov, count, datatype, buf, len);
		psmx2_ioc_read(comparev, compare_count, datatype, buf + len, len);
	} else {
		buf = iov[0].addr;
	}

	/*
	 * Record where the result goes: either a private copy of resultv
	 * (multi-entry, result == NULL) or the single result address.
	 * NOTE(review): req and req->tmpbuf are not released here; presumably
	 * the AM reply handler consumes these fields and frees them - confirm.
	 */
	req->ioc = req->tmpbuf;
	if (iov_size) {
		memcpy(req->ioc, resultv, iov_size);
		req->atomic.iov_count = result_count;
		req->atomic.result = NULL;
	} else {
		req->atomic.buf = buf;
		req->atomic.result = resultv[0].addr;
	}

	req->no_event = (flags & PSMX2_NO_COMPLETION) ||
			(ep_priv->send_selective_completion && !(flags & FI_COMPLETION));

	req->op = PSMX2_AM_REQ_ATOMIC_COMPWRITE;
	req->atomic.len = len;
	req->atomic.addr = addr;
	req->atomic.key = key;
	req->atomic.context = context;
	req->atomic.datatype = datatype;
	req->ep = ep_priv;
	req->cq_flags = FI_WRITE | FI_ATOMIC;

	args[0].u32w0 = PSMX2_AM_REQ_ATOMIC_COMPWRITE;
	args[0].u32w1 = len / ofi_datatype_size(datatype);	/* element count */
	args[1].u64 = (uint64_t)(uintptr_t)req;
	args[2].u64 = addr;
	args[3].u64 = key;
	args[4].u32w0 = datatype;
	args[4].u32w1 = op;
	/*
	 * NOTE(review): the return value of psm2_am_request_short() is not
	 * checked; if it can fail, req and req->tmpbuf would never be
	 * released and the caller would still see success - confirm.
	 */
	psm2_am_request_short(psm2_epaddr,
			      PSMX2_AM_ATOMIC_HANDLER, args, 5,
			      buf, len * 2, am_flags, NULL, NULL);
	psmx2_am_poll(ep_priv->tx);
	return 0;
}