int psmx2_am_rma_handler()

in prov/psm2/src/psmx2_rma.c [103:418]


/*
 * Active-message handler for RMA requests and replies.
 *
 * The command is decoded from args[0].u32w0 together with the
 * PSMX2_AM_EOM and PSMX2_AM_DATA flag bits. The remaining arguments are
 * used as follows (as read by the code below):
 *
 *   requests: args[0].u32w1 = length, args[1].u64 = opaque value echoed
 *             back in the reply, args[2].u64 = address, args[3].u64 = MR
 *             key, args[4].u64 = CQ data (writes) or offset (reads).
 *   replies:  args[0].u32w1 = error code, args[1].u64 = pointer to the
 *             struct psmx2_am_request the reply belongs to,
 *             args[2].u64 = offset (read replies only).
 *
 * For request commands hctx is interpreted as the receiving
 * struct psmx2_trx_ctxt; src/len describe the payload carried by the
 * active message.
 *
 * Returns 0 on success, -FI_ENOMEM if a request or CQ event could not be
 * allocated, -FI_EINVAL for an unknown command, or the status returned by
 * psm2_am_reply_short() when sending a reply.
 */
int psmx2_am_rma_handler(psm2_am_token_t token, psm2_amarg_t *args,
		int nargs, void *src, uint32_t len,
		void *hctx)
{
	psm2_amarg_t rep_args[8];
	uint8_t *rma_addr;
	ssize_t rma_len;
	uint64_t key;
	int err = 0;
	int reply_err;
	int op_error = 0;
	int cmd, eom, has_data;
	struct psmx2_am_request *req;
	struct psmx2_cq_event *event;
	uint64_t offset;
	struct psmx2_fid_mr *mr;
	psm2_epaddr_t epaddr;
	struct psmx2_trx_ctxt *rx;

#if HAVE_PSM2_MQ_FP_MSG
	psm2_mq_req_t psm2_req;
	psm2_mq_tag_t psm2_tag, psm2_tagsel;
#endif

	psm2_am_get_source(token, &epaddr);
	cmd = PSMX2_AM_GET_OP(args[0].u32w0);
	eom = args[0].u32w0 & PSMX2_AM_EOM;
	has_data = args[0].u32w0 & PSMX2_AM_DATA;

	switch (cmd) {
	case PSMX2_AM_REQ_WRITE:
		/*
		 * Write whose data travels in the AM payload: validate the
		 * range covered by this fragment (len bytes) and copy it in.
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, len, FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (!op_error) {
			rma_addr += mr->offset;
			memcpy(rma_addr, src, len);
			if (eom) {
				/* A CQ entry is only generated when CQ data was sent. */
				if (rx->ep->recv_cq && has_data) {
					/* TODO: report the addr/len of the whole write */
					event = psmx2_cq_create_event(
							rx->ep->recv_cq,
							0, /* context */
							rma_addr,
							FI_REMOTE_WRITE | FI_RMA | FI_REMOTE_CQ_DATA,
							rma_len,
							args[4].u64,
							0, /* tag */
							0, /* olen */
							0);

					if (event)
						psmx2_cq_enqueue_event(rx->ep->recv_cq, event);
					else
						err = -FI_ENOMEM;
				}

				if (rx->ep->caps & FI_RMA_EVENT) {
					if (rx->ep->remote_write_cntr)
						psmx2_cntr_inc(rx->ep->remote_write_cntr, 0);

					/* Avoid counting twice when both refer to the same counter. */
					if (mr->cntr && mr->cntr != rx->ep->remote_write_cntr)
						psmx2_cntr_inc(mr->cntr, 0);
				}
			}
		}
		/* Reply on the last fragment, or immediately on any failure. */
		if (eom || op_error) {
			rep_args[0].u32w0 = PSMX2_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			reply_err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
							rep_args, 2, NULL, 0, 0,
							NULL, NULL);
			/*
			 * Do not let the reply status hide an earlier
			 * -FI_ENOMEM from the CQ event allocation above.
			 */
			if (!err)
				err = reply_err;
		}
		break;

	case PSMX2_AM_REQ_WRITE_LONG:
		/*
		 * Write whose data does not travel in this AM: validate the
		 * whole range (rma_len) and set up a request to receive it.
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX2_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
						  rep_args, 2, NULL, 0, 0,
						  NULL, NULL);
			break;
		}

		rma_addr += mr->offset;

		req = psmx2_am_request_alloc(rx);
		if (!req) {
			err = -FI_ENOMEM;
		} else {
			req->ep = rx->ep;
			req->op = args[0].u32w0;
			req->write.addr = (uint64_t)rma_addr;
			req->write.len = rma_len;
			req->write.key = key;
			req->write.context = (void *)(uintptr_t)args[1].u64;
			req->write.peer_addr = (void *)epaddr;
			req->write.data = has_data ? args[4].u64 : 0;
			req->cq_flags = FI_REMOTE_WRITE | FI_RMA |
					(has_data ? FI_REMOTE_CQ_DATA : 0);
			PSMX2_CTXT_TYPE(&req->fi_context) = PSMX2_REMOTE_WRITE_CONTEXT;
			PSMX2_CTXT_USER(&req->fi_context) = mr;
#if HAVE_PSM2_MQ_FP_MSG
			/* Post the receive for the data directly from the handler. */
			PSMX2_SET_TAG(psm2_tag, (uint64_t)req->write.context, 0,
					PSMX2_RMA_TYPE_WRITE);
			PSMX2_SET_MASK(psm2_tagsel, PSMX2_MATCH_ALL, PSMX2_RMA_TYPE_MASK);
			op_error = psm2_mq_fp_msg(rx->psm2_ep, rx->psm2_mq,
						 (psm2_epaddr_t)epaddr,
						 &psm2_tag, &psm2_tagsel, 0,
						 (void *)rma_addr, rma_len,
						 (void *)&req->fi_context, PSM2_MQ_IRECV_FP, &psm2_req);
			if (op_error) {
				rep_args[0].u32w0 = PSMX2_AM_REP_WRITE | eom;
				rep_args[0].u32w1 = op_error;
				rep_args[1].u64 = args[1].u64;
				err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
							  rep_args, 2, NULL, 0, 0,
							  NULL, NULL);
				psmx2_am_request_free(rx, req);
				break;
			}
#else
			/* Hand the request over via the RMA queue instead. */
			psmx2_am_enqueue_rma(rx, req);
#endif
		}
		break;

	case PSMX2_AM_REQ_READ:
		/*
		 * Read served through the reply payload: the reply carries
		 * rma_len bytes starting at the validated address, or no
		 * payload at all when validation fails.
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		offset = args[4].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_READ) :
			-FI_EINVAL;
		if (!op_error) {
			rma_addr += mr->offset;
		} else {
			rma_addr = NULL;
			rma_len = 0;
		}

		rep_args[0].u32w0 = PSMX2_AM_REP_READ | eom;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		rep_args[2].u64 = offset;	/* echoed so the reply handler knows where to copy */
		err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
				rep_args, 3, rma_addr, rma_len, 0,
				NULL, NULL);

		if (eom && !op_error) {
			if (rx->ep->caps & FI_RMA_EVENT) {
				if (rx->ep->remote_read_cntr)
					psmx2_cntr_inc(rx->ep->remote_read_cntr, 0);
			}
		}
		break;

	case PSMX2_AM_REQ_READ_LONG:
		/*
		 * Read whose data is not returned in the AM reply: validate
		 * the range and set up a request to send it.
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_READ) :
			-FI_EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX2_AM_REP_READ | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			rep_args[2].u64 = 0;
			err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
					rep_args, 3, NULL, 0, 0,
					NULL, NULL);
			break;
		}

		rma_addr += mr->offset;

		req = psmx2_am_request_alloc(rx);
		if (!req) {
			err = -FI_ENOMEM;
		} else {
			req->ep = rx->ep;
			req->op = args[0].u32w0;
			req->read.addr = (uint64_t)rma_addr;
			req->read.len = rma_len;
			req->read.key = key;
			req->read.context = (void *)(uintptr_t)args[1].u64;
			req->read.peer_addr = (void *)epaddr;
			PSMX2_CTXT_TYPE(&req->fi_context) = PSMX2_REMOTE_READ_CONTEXT;
			PSMX2_CTXT_USER(&req->fi_context) = mr;
#if HAVE_PSM2_MQ_FP_MSG
			/* Post the send of the data directly from the handler. */
			PSMX2_SET_TAG(psm2_tag, (uint64_t)req->read.context, 0,
					PSMX2_RMA_TYPE_READ);
			op_error = psm2_mq_fp_msg(rx->psm2_ep, rx->psm2_mq,
				  (psm2_epaddr_t)req->read.peer_addr,
				 &psm2_tag, 0, 0,
				 (void *)req->read.addr, req->read.len,
				 (void *)&req->fi_context, PSM2_MQ_ISEND_FP, &psm2_req);
			if (op_error) {
				rep_args[0].u32w0 = PSMX2_AM_REP_READ | eom;
				rep_args[0].u32w1 = op_error;
				rep_args[1].u64 = args[1].u64;
				rep_args[2].u64 = 0;
				err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
						rep_args, 3, NULL, 0, 0,
						NULL, NULL);
				psmx2_am_request_free(rx, req);
				break;
			}
#else
			/* Hand the request over via the RMA queue instead. */
			psmx2_am_enqueue_rma(rx, req);
#endif
		}
		break;

	case PSMX2_AM_REP_WRITE:
		/* Reply to a write: args[1] is the request pointer that was echoed back. */
		req = (struct psmx2_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX2_AM_REQ_WRITE);
		op_error = (int)args[0].u32w1;
		/* Remember only the first error reported for this request. */
		if (!req->error)
			req->error = op_error;
		if (eom) {
			/* Errors are always reported, even when events are suppressed. */
			if (req->ep->send_cq && (!req->no_event || req->error)) {
				event = psmx2_cq_create_event(
						req->ep->send_cq,
						req->write.context,
						req->write.buf,
						req->cq_flags,
						req->write.len,
						0, /* data */
						0, /* tag */
						0, /* olen */
						req->error);
				if (event)
					psmx2_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -FI_ENOMEM;
			}

			if (req->ep->write_cntr)
				psmx2_cntr_inc(req->ep->write_cntr, req->error);

			free(req->tmpbuf);
			psmx2_am_request_free(req->ep->tx, req);
		}
		break;

	case PSMX2_AM_REP_READ:
		/* Reply to a read: args[1] is the request pointer that was echoed back. */
		req = (struct psmx2_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX2_AM_REQ_READ || req->op == PSMX2_AM_REQ_READV);
		op_error = (int)args[0].u32w1;
		offset = args[2].u64;
		/* Remember only the first error reported for this request. */
		if (!req->error)
			req->error = op_error;
		if (!op_error) {
			/* Place the returned payload at the offset it was requested for. */
			if (req->op == PSMX2_AM_REQ_READ)
				memcpy(req->read.buf + offset, src, len);
			else
				psmx2_iov_copy(req->iov, req->read.iov_count, offset, src, len);

			req->read.len_read += len;
		}
		/* Complete on the EOM reply or once all requested bytes have arrived. */
		if (eom || req->read.len == req->read.len_read) {
			if (!eom)
				FI_INFO(&psmx2_prov, FI_LOG_EP_DATA,
					"readv: short protocol finishes after long protocol.\n");
			if (req->ep->send_cq && (!req->no_event || req->error)) {
				event = psmx2_cq_create_event(
						req->ep->send_cq,
						req->read.context,
						req->read.buf,
						req->cq_flags,
						req->read.len_read,
						0, /* data */
						0, /* tag */
						req->read.len - req->read.len_read,
						req->error);
				if (event)
					psmx2_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -FI_ENOMEM;
			}

			if (req->ep->read_cntr)
				psmx2_cntr_inc(req->ep->read_cntr, req->error);

			free(req->tmpbuf);
			psmx2_am_request_free(req->ep->tx, req);
		}
		break;

	default:
		err = -FI_EINVAL;
	}
	return err;
}