int psmx2_am_atomic_handler()

in prov/psm2/src/psmx2_atomic.c [413:650]


/*
 * psmx2_am_atomic_handler - PSM2 active-message handler for atomics.
 *
 * Dispatches on PSMX2_AM_GET_OP(args[0].u32w0).
 *
 * Target side (PSMX2_AM_REQ_ATOMIC_*; hctx is the receiving
 * struct psmx2_trx_ctxt):
 *   args[0].u32w1  element count
 *   args[1].u64    initiator cookie, echoed unchanged in the reply
 *   args[2].u64    target address as sent by the initiator; mr->offset is
 *                  added to it after the MR has been validated
 *   args[3].u64    MR key
 *   args[4].u32w0  datatype
 *   args[4].u32w1  atomic op
 *   src/len        operand payload; for COMPWRITE it holds two buffers of
 *                  len/2 bytes each, passed separately to
 *                  psmx2_atomic_do_compwrite()
 * The MR is looked up and validated, the operation is applied, remote
 * counters are bumped (only when the endpoint has FI_RMA_EVENT), and a
 * reply carrying op_error is always sent. Fetching operations return the
 * old contents as the reply payload.
 *
 * Initiator side (PSMX2_AM_REP_ATOMIC_*):
 *   args[0].u32w1  op_error reported by the target
 *   args[1].u64    the struct psmx2_am_request * that issued the op
 * Fetched data is copied out, a CQ event is generated when a send CQ is
 * bound (unless req->no_event is set; always on error), the endpoint's
 * write/read counter is updated, and the request is released.
 *
 * Returns 0, the status of psm2_am_reply_short(), -FI_ENOMEM if a CQ
 * event could not be allocated, or -FI_EINVAL for an unknown opcode.
 */
int psmx2_am_atomic_handler(psm2_am_token_t token,
				psm2_amarg_t *args, int nargs, void *src,
				uint32_t len, void *hctx)
{
	psm2_amarg_t rep_args[8];
	int count;
	uint8_t *addr;
	uint64_t key;
	int datatype, op;
	int err = 0;
	int op_error = 0;
	struct psmx2_am_request *req;
	struct psmx2_cq_event *event;
	struct psmx2_fid_mr *mr = NULL;
	struct psmx2_fid_cntr *cntr = NULL;
	struct psmx2_fid_cntr *mr_cntr = NULL;
	void *tmp_buf;
	psm2_epaddr_t epaddr;
	int cmd;
	struct psmx2_trx_ctxt *rx;

	psm2_am_get_source(token, &epaddr);
	cmd = PSMX2_AM_GET_OP(args[0].u32w0);

	switch (cmd) {
	case PSMX2_AM_REQ_ATOMIC_WRITE:
		rx = (struct psmx2_trx_ctxt *)hctx;
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;

		/*
		 * len and count both come from the peer. Check them at run
		 * time instead of with assert(), which compiles away under
		 * NDEBUG and could let the atomic helper touch more than len
		 * bytes of src.
		 */
		if (len != ofi_datatype_size(datatype) * count) {
			op_error = -FI_EINVAL;
		} else {
			mr = psmx2_mr_get(rx->domain, key);
			op_error = mr ?
				psmx2_mr_validate(mr, (uint64_t)addr, len,
						  FI_REMOTE_WRITE) :
				-FI_EINVAL;
		}

		if (!op_error) {
			addr += mr->offset;
			psmx2_atomic_do_write(addr, src, datatype, op, count);

			if (rx->ep->caps & FI_RMA_EVENT) {
				cntr = rx->ep->remote_write_cntr;
				mr_cntr = mr->cntr;

				if (cntr)
					psmx2_cntr_inc(cntr, 0);

				/* Don't count twice when the MR shares the
				 * endpoint's counter. */
				if (mr_cntr && mr_cntr != cntr)
					psmx2_cntr_inc(mr_cntr, 0);
			}
		}

		rep_args[0].u32w0 = PSMX2_AM_REP_ATOMIC_WRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm2_am_reply_short(token, PSMX2_AM_ATOMIC_HANDLER,
					  rep_args, 2, NULL, 0, 0,
					  NULL, NULL );
		break;

	case PSMX2_AM_REQ_ATOMIC_READWRITE:
		rx = (struct psmx2_trx_ctxt *)hctx;
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;
		tmp_buf = NULL;

		/* For FI_ATOMIC_READ the payload length is not used; the size
		 * of the data to fetch is derived from datatype and count. */
		if (op == FI_ATOMIC_READ)
			len = ofi_datatype_size(datatype) * count;

		/* Run-time check (not assert): also rejects a size that did
		 * not fit in the uint32_t len above. */
		if (len != ofi_datatype_size(datatype) * count) {
			op_error = -FI_EINVAL;
		} else {
			mr = psmx2_mr_get(rx->domain, key);
			op_error = mr ?
				psmx2_mr_validate(mr, (uint64_t)addr, len,
						  FI_REMOTE_READ|FI_REMOTE_WRITE) :
				-FI_EINVAL;
		}

		if (!op_error) {
			tmp_buf = malloc(len);
			if (!tmp_buf)
				op_error = -FI_ENOMEM;
		}

		/* Touch target memory and counters only after every check has
		 * passed, so a failed allocation is not counted as a completed
		 * remote access. */
		if (!op_error) {
			addr += mr->offset;
			psmx2_atomic_do_readwrite(addr, src, tmp_buf,
						  datatype, op, count);

			if (rx->ep->caps & FI_RMA_EVENT) {
				if (op == FI_ATOMIC_READ) {
					cntr = rx->ep->remote_read_cntr;
				} else {
					cntr = rx->ep->remote_write_cntr;
					mr_cntr = mr->cntr;
				}

				if (cntr)
					psmx2_cntr_inc(cntr, 0);

				if (mr_cntr && mr_cntr != cntr)
					psmx2_cntr_inc(mr_cntr, 0);
			}
		}

		rep_args[0].u32w0 = PSMX2_AM_REP_ATOMIC_READWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		/* free/tmp_buf are handed to PSM2 as the completion callback and
		 * its argument; presumably tmp_buf is released once the reply has
		 * been sent. */
		err = psm2_am_reply_short(token, PSMX2_AM_ATOMIC_HANDLER,
					  rep_args, 2, tmp_buf,
					  (tmp_buf ? len : 0),
					  0, free, tmp_buf );
		break;

	case PSMX2_AM_REQ_ATOMIC_COMPWRITE:
		rx = (struct psmx2_trx_ctxt *)hctx;
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;
		tmp_buf = NULL;

		/* The payload holds two equally sized buffers back to back;
		 * from here on len is the size of one of them. */
		len /= 2;

		if (len != ofi_datatype_size(datatype) * count) {
			op_error = -FI_EINVAL;
		} else {
			mr = psmx2_mr_get(rx->domain, key);
			op_error = mr ?
				psmx2_mr_validate(mr, (uint64_t)addr, len,
						  FI_REMOTE_READ|FI_REMOTE_WRITE) :
				-FI_EINVAL;
		}

		if (!op_error) {
			tmp_buf = malloc(len);
			if (!tmp_buf)
				op_error = -FI_ENOMEM;
		}

		/* As above: no memory access or counter update unless every
		 * check, including the allocation, succeeded. */
		if (!op_error) {
			addr += mr->offset;
			psmx2_atomic_do_compwrite(addr, src, (uint8_t *)src + len,
						  tmp_buf, datatype,
						  op, count);

			if (rx->ep->caps & FI_RMA_EVENT) {
				cntr = rx->ep->remote_write_cntr;
				mr_cntr = mr->cntr;

				if (cntr)
					psmx2_cntr_inc(cntr, 0);

				if (mr_cntr && mr_cntr != cntr)
					psmx2_cntr_inc(mr_cntr, 0);
			}
		}

		/* Reply with the opcode matching this request. The initiator
		 * handles READWRITE and COMPWRITE replies identically. */
		rep_args[0].u32w0 = PSMX2_AM_REP_ATOMIC_COMPWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm2_am_reply_short(token, PSMX2_AM_ATOMIC_HANDLER,
					  rep_args, 2, tmp_buf,
					  (tmp_buf ? len : 0),
					  0, free, tmp_buf );
		break;

	case PSMX2_AM_REP_ATOMIC_WRITE:
		req = (struct psmx2_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(req->op == PSMX2_AM_REQ_ATOMIC_WRITE);

		/* Errors are always reported, even for no_event requests. */
		if (req->ep->send_cq && (!req->no_event || op_error)) {
			event = psmx2_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx2_cq_enqueue_event(req->ep->send_cq, event);
			else
				err = -FI_ENOMEM;
		}

		if (req->ep->write_cntr)
			psmx2_cntr_inc(req->ep->write_cntr, op_error);

		free(req->tmpbuf);
		psmx2_am_request_free(req->ep->tx, req);
		break;

	case PSMX2_AM_REP_ATOMIC_READWRITE:
	case PSMX2_AM_REP_ATOMIC_COMPWRITE:
		req = (struct psmx2_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(op_error || req->atomic.len == len);

		/* On success the payload is the fetched old value. Copy it into
		 * the flat result buffer if one was recorded, otherwise scatter
		 * it into the request's ioc list. */
		if (!op_error) {
			if (req->atomic.result)
				memcpy(req->atomic.result, src, len);
			else
				psmx2_ioc_write(req->ioc, req->atomic.iov_count,
						req->atomic.datatype, src, len);
		}

		if (req->ep->send_cq && (!req->no_event || op_error)) {
			event = psmx2_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx2_cq_enqueue_event(req->ep->send_cq, event);
			else
				err = -FI_ENOMEM;
		}

		if (req->ep->read_cntr)
			psmx2_cntr_inc(req->ep->read_cntr, op_error);

		free(req->tmpbuf);
		psmx2_am_request_free(req->ep->tx, req);
		break;

	default:
		err = -FI_EINVAL;
		break;
	}
	return err;
}