in prov/psm2/src/psmx2_rma.c [103:418]
/*
 * Active-message handler for RMA requests and replies.
 *
 * The opcode and flags are taken from args[0].u32w0:
 *   - PSMX2_AM_GET_OP() yields the command handled by the switch below,
 *   - PSMX2_AM_EOM marks the last message of an operation,
 *   - PSMX2_AM_DATA indicates that remote CQ data accompanies a write.
 *
 * Request opcodes (REQ_*) treat hctx as a struct psmx2_trx_ctxt pointer and
 * look up the memory region named by the key in args[3].u64.  args[1].u64 is
 * echoed back unchanged in every reply; the reply opcodes (REP_*) interpret
 * that value as a pointer to the struct psmx2_am_request being completed.
 *
 * token  AM token, used to find the source epaddr and to send replies
 * args   AM arguments; layout depends on the opcode (see each case)
 * nargs  number of AM arguments (not examined by this handler)
 * src    payload of the active message
 * len    payload length in bytes
 * hctx   handler context (struct psmx2_trx_ctxt * for request opcodes)
 *
 * Returns 0 on success, -FI_ENOMEM when a request or CQ event cannot be
 * allocated, -FI_EINVAL for an unknown opcode, or the value returned by
 * psm2_am_reply_short() when a reply was sent and no earlier error was
 * recorded.
 */
int psmx2_am_rma_handler(psm2_am_token_t token, psm2_amarg_t *args,
			 int nargs, void *src, uint32_t len,
			 void *hctx)
{
	psm2_amarg_t rep_args[8];
	uint8_t *rma_addr;
	ssize_t rma_len;
	uint64_t key;
	int err = 0;
	int reply_err;
	int op_error = 0;
	int cmd, eom, has_data;
	struct psmx2_am_request *req;
	struct psmx2_cq_event *event;
	uint64_t offset;
	struct psmx2_fid_mr *mr;
	psm2_epaddr_t epaddr;
	struct psmx2_trx_ctxt *rx;
#if HAVE_PSM2_MQ_FP_MSG
	psm2_mq_req_t psm2_req;
	psm2_mq_tag_t psm2_tag, psm2_tagsel;
#endif

	psm2_am_get_source(token, &epaddr);

	cmd = PSMX2_AM_GET_OP(args[0].u32w0);
	eom = args[0].u32w0 & PSMX2_AM_EOM;
	has_data = args[0].u32w0 & PSMX2_AM_DATA;

	switch (cmd) {
	case PSMX2_AM_REQ_WRITE:
		/*
		 * Write carried in the AM payload.  The MR check covers only
		 * this message's payload (len); rma_len from args[0].u32w1 is
		 * what gets reported in the remote CQ event.
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, len, FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (!op_error) {
			rma_addr += mr->offset;
			memcpy(rma_addr, src, len);
			if (eom) {
				if (rx->ep->recv_cq && has_data) {
					/* TODO: report the addr/len of the whole write */
					event = psmx2_cq_create_event(
							rx->ep->recv_cq,
							0, /* context */
							rma_addr,
							FI_REMOTE_WRITE | FI_RMA | FI_REMOTE_CQ_DATA,
							rma_len,
							args[4].u64,
							0, /* tag */
							0, /* olen */
							0);
					if (event)
						psmx2_cq_enqueue_event(rx->ep->recv_cq, event);
					else
						err = -FI_ENOMEM;
				}

				if (rx->ep->caps & FI_RMA_EVENT) {
					if (rx->ep->remote_write_cntr)
						psmx2_cntr_inc(rx->ep->remote_write_cntr, 0);

					/* Avoid counting twice when the MR shares the EP counter. */
					if (mr->cntr && mr->cntr != rx->ep->remote_write_cntr)
						psmx2_cntr_inc(mr->cntr, 0);
				}
			}
		}
		if (eom || op_error) {
			rep_args[0].u32w0 = PSMX2_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			reply_err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
							rep_args, 2, NULL, 0, 0,
							NULL, NULL);
			/*
			 * Do not let the reply status hide an event allocation
			 * failure recorded above.
			 */
			if (!err)
				err = reply_err;
		}
		break;

	case PSMX2_AM_REQ_WRITE_LONG:
		/*
		 * Write whose data does not travel in this AM.  After the MR
		 * check a request is set up and the data is received either
		 * through psm2_mq_fp_msg() or by queueing the request with
		 * psmx2_am_enqueue_rma().  A reply is sent from here only on
		 * failure of the MR check or of psm2_mq_fp_msg().
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX2_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
						  rep_args, 2, NULL, 0, 0,
						  NULL, NULL);
			break;
		}

		rma_addr += mr->offset;

		req = psmx2_am_request_alloc(rx);
		if (!req) {
			/*
			 * NOTE(review): no reply is sent on this path -- confirm
			 * the initiator cannot end up waiting indefinitely.
			 */
			err = -FI_ENOMEM;
		} else {
			req->ep = rx->ep;
			req->op = args[0].u32w0;
			req->write.addr = (uint64_t)rma_addr;
			req->write.len = rma_len;
			req->write.key = key;
			req->write.context = (void *)args[1].u64;
			req->write.peer_addr = (void *)epaddr;
			req->write.data = has_data ? args[4].u64 : 0;
			req->cq_flags = FI_REMOTE_WRITE | FI_RMA |
					(has_data ? FI_REMOTE_CQ_DATA : 0);
			PSMX2_CTXT_TYPE(&req->fi_context) = PSMX2_REMOTE_WRITE_CONTEXT;
			PSMX2_CTXT_USER(&req->fi_context) = mr;
#if HAVE_PSM2_MQ_FP_MSG
			/* The tag is built from the value echoed in args[1].u64. */
			PSMX2_SET_TAG(psm2_tag, (uint64_t)req->write.context, 0,
				      PSMX2_RMA_TYPE_WRITE);
			PSMX2_SET_MASK(psm2_tagsel, PSMX2_MATCH_ALL, PSMX2_RMA_TYPE_MASK);
			op_error = psm2_mq_fp_msg(rx->psm2_ep, rx->psm2_mq,
						  (psm2_epaddr_t)epaddr,
						  &psm2_tag, &psm2_tagsel, 0,
						  (void *)rma_addr, rma_len,
						  (void *)&req->fi_context, PSM2_MQ_IRECV_FP, &psm2_req);
			if (op_error) {
				rep_args[0].u32w0 = PSMX2_AM_REP_WRITE | eom;
				rep_args[0].u32w1 = op_error;
				rep_args[1].u64 = args[1].u64;
				err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
							  rep_args, 2, NULL, 0, 0,
							  NULL, NULL);
				psmx2_am_request_free(rx, req);
				break;
			}
#else
			psmx2_am_enqueue_rma(rx, req);
#endif
		}
		break;

	case PSMX2_AM_REQ_READ:
		/*
		 * Read answered directly in the reply payload.  args[4].u64 is
		 * an offset that is passed back to the requester untouched.
		 * On a failed MR check the reply carries no payload.
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		offset = args[4].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_READ) :
			-FI_EINVAL;
		if (!op_error) {
			rma_addr += mr->offset;
		} else {
			rma_addr = NULL;
			rma_len = 0;
		}

		rep_args[0].u32w0 = PSMX2_AM_REP_READ | eom;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		rep_args[2].u64 = offset;
		err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
					  rep_args, 3, rma_addr, rma_len, 0,
					  NULL, NULL);

		if (eom && !op_error) {
			if (rx->ep->caps & FI_RMA_EVENT) {
				if (rx->ep->remote_read_cntr)
					psmx2_cntr_inc(rx->ep->remote_read_cntr, 0);
			}
		}
		break;

	case PSMX2_AM_REQ_READ_LONG:
		/*
		 * Read whose data is not returned in the AM reply.  After the
		 * MR check a request is set up and the data is sent either
		 * through psm2_mq_fp_msg() or by queueing the request with
		 * psmx2_am_enqueue_rma().  A reply is sent from here only on
		 * failure of the MR check or of psm2_mq_fp_msg().
		 */
		rx = (struct psmx2_trx_ctxt *)hctx;
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx2_mr_get(rx->domain, key);
		op_error = mr ?
			psmx2_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_READ) :
			-FI_EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX2_AM_REP_READ | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			rep_args[2].u64 = 0;
			err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
						  rep_args, 3, NULL, 0, 0,
						  NULL, NULL);
			break;
		}

		rma_addr += mr->offset;

		req = psmx2_am_request_alloc(rx);
		if (!req) {
			/*
			 * NOTE(review): no reply is sent on this path -- confirm
			 * the initiator cannot end up waiting indefinitely.
			 */
			err = -FI_ENOMEM;
		} else {
			req->ep = rx->ep;
			req->op = args[0].u32w0;
			req->read.addr = (uint64_t)rma_addr;
			req->read.len = rma_len;
			req->read.key = key;
			req->read.context = (void *)args[1].u64;
			req->read.peer_addr = (void *)epaddr;
			PSMX2_CTXT_TYPE(&req->fi_context) = PSMX2_REMOTE_READ_CONTEXT;
			PSMX2_CTXT_USER(&req->fi_context) = mr;
#if HAVE_PSM2_MQ_FP_MSG
			/* The tag is built from the value echoed in args[1].u64. */
			PSMX2_SET_TAG(psm2_tag, (uint64_t)req->read.context, 0,
				      PSMX2_RMA_TYPE_READ);
			op_error = psm2_mq_fp_msg(rx->psm2_ep, rx->psm2_mq,
						  (psm2_epaddr_t)req->read.peer_addr,
						  &psm2_tag, 0, 0,
						  (void *)req->read.addr, req->read.len,
						  (void *)&req->fi_context, PSM2_MQ_ISEND_FP, &psm2_req);
			if (op_error) {
				rep_args[0].u32w0 = PSMX2_AM_REP_READ | eom;
				rep_args[0].u32w1 = op_error;
				rep_args[1].u64 = args[1].u64;
				rep_args[2].u64 = 0;
				err = psm2_am_reply_short(token, PSMX2_AM_RMA_HANDLER,
							  rep_args, 3, NULL, 0, 0,
							  NULL, NULL);
				psmx2_am_request_free(rx, req);
				break;
			}
#else
			psmx2_am_enqueue_rma(rx, req);
#endif
		}
		break;

	case PSMX2_AM_REP_WRITE:
		/*
		 * Reply to a write.  The first non-zero status seen is kept in
		 * req->error; the request is completed and released once the
		 * reply carrying EOM arrives.
		 */
		req = (struct psmx2_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX2_AM_REQ_WRITE);
		op_error = (int)args[0].u32w1;
		if (!req->error)
			req->error = op_error;
		if (eom) {
			/* Errors are reported even when the event was suppressed. */
			if (req->ep->send_cq && (!req->no_event || req->error)) {
				event = psmx2_cq_create_event(
						req->ep->send_cq,
						req->write.context,
						req->write.buf,
						req->cq_flags,
						req->write.len,
						0, /* data */
						0, /* tag */
						0, /* olen */
						req->error);
				if (event)
					psmx2_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -FI_ENOMEM;
			}

			if (req->ep->write_cntr)
				psmx2_cntr_inc(req->ep->write_cntr, req->error);

			free(req->tmpbuf);
			psmx2_am_request_free(req->ep->tx, req);
		}
		break;

	case PSMX2_AM_REP_READ:
		/*
		 * Reply to a read.  On success the payload is copied to the
		 * destination at the offset echoed in args[2].u64 and the
		 * running total req->read.len_read is advanced.  The request
		 * completes on EOM or as soon as all requested bytes arrived.
		 */
		req = (struct psmx2_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX2_AM_REQ_READ || req->op == PSMX2_AM_REQ_READV);
		op_error = (int)args[0].u32w1;
		offset = args[2].u64;
		if (!req->error)
			req->error = op_error;
		if (!op_error) {
			/*
			 * NOTE(review): offset and len come from the reply and are
			 * not checked against req->read.len here -- confirm that
			 * the peer is trusted to stay within the buffer.
			 */
			if (req->op == PSMX2_AM_REQ_READ)
				memcpy(req->read.buf + offset, src, len);
			else
				psmx2_iov_copy(req->iov, req->read.iov_count, offset, src, len);

			req->read.len_read += len;
		}
		if (eom || req->read.len == req->read.len_read) {
			if (!eom)
				FI_INFO(&psmx2_prov, FI_LOG_EP_DATA,
					"readv: short protocol finishes after long protocol.\n");
			/* Errors are reported even when the event was suppressed. */
			if (req->ep->send_cq && (!req->no_event || req->error)) {
				event = psmx2_cq_create_event(
						req->ep->send_cq,
						req->read.context,
						req->read.buf,
						req->cq_flags,
						req->read.len_read,
						0, /* data */
						0, /* tag */
						req->read.len - req->read.len_read,
						req->error);
				if (event)
					psmx2_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -FI_ENOMEM;
			}

			if (req->ep->read_cntr)
				psmx2_cntr_inc(req->ep->read_cntr, req->error);

			free(req->tmpbuf);
			psmx2_am_request_free(req->ep->tx, req);
		}
		break;

	default:
		err = -FI_EINVAL;
	}
	return err;
}