in prov/psm/src/psmx_rma.c [74:323]
/*
 * Active-message handler for PSMX RMA traffic.
 *
 * The opcode is args[0].u32w0 & PSMX_AM_OP_MASK.  The same word also carries
 * the PSMX_AM_EOM flag (last chunk of an operation) and the PSMX_AM_DATA flag
 * (remote CQ data present).
 *
 * Target side (requests arriving from a peer):
 *   PSMX_AM_REQ_WRITE       copy an inline chunk into a registered MR.  On
 *                           the last chunk, raise the remote-write CQ event
 *                           and counters.  On the last chunk or on error,
 *                           reply with PSMX_AM_REP_WRITE.
 *   PSMX_AM_REQ_WRITE_LONG  validate the MR and queue a psmx_am_request
 *                           tagged PSMX_REMOTE_WRITE_CONTEXT through
 *                           psmx_am_enqueue_rma().  Replies only on a
 *                           validation failure.
 *   PSMX_AM_REQ_READ        validate the MR and reply with PSMX_AM_REP_READ
 *                           carrying the requested bytes inline.
 *   PSMX_AM_REQ_READ_LONG   validate the MR and queue a psmx_am_request
 *                           tagged PSMX_REMOTE_READ_CONTEXT.  Replies only
 *                           on a validation failure.
 *
 * Initiator side (replies to requests this process sent):
 *   PSMX_AM_REP_WRITE,
 *   PSMX_AM_REP_READ        args[1] is the psmx_am_request pointer that was
 *                           sent with the request.  Record the first error,
 *                           copy returned read data, and on the last chunk
 *                           post the send-CQ event, bump the counter and
 *                           free the request.
 *
 * Request argument layout as consumed here:
 *   args[0].u32w1  total RMA length (requests) / status (replies)
 *   args[1].u64    initiator's request handle, echoed back in replies
 *   args[2].u64    target address (requests) / data offset (REP_READ)
 *   args[3].u64    MR key
 *   args[4].u64    remote CQ data (REQ_WRITE), chunk offset (REQ_READ),
 *                  or request context (REQ_WRITE_LONG / REQ_READ_LONG)
 *
 * @param token   PSM token, needed to send a reply from this handler
 * @param epaddr  sender's endpoint address
 * @param args    AM argument words (layout above)
 * @param nargs   number of argument words (unused here)
 * @param src     inline payload
 * @param len     inline payload length in bytes
 *
 * Returns 0 on success or the first error encountered.  That error is
 * -FI_ENOMEM, -FI_EINVAL, or the status returned by psm_am_reply_short().
 * NOTE(review): confirm whether PSM acts on a non-zero handler return.
 */
int psmx_am_rma_handler(psm_am_token_t token, psm_epaddr_t epaddr,
			psm_amarg_t *args, int nargs, void *src, uint32_t len)
{
	psm_amarg_t rep_args[8];
	uint8_t *rma_addr;
	ssize_t rma_len;
	uint64_t key;
	uint64_t cq_data;
	int err = 0;
	int rep_err;
	int op_error = 0;
	int cmd, eom, has_data;
	struct psmx_am_request *req;
	struct psmx_cq_event *event;
	uint64_t offset;
	struct psmx_fid_mr *mr;

	cmd = args[0].u32w0 & PSMX_AM_OP_MASK;
	eom = args[0].u32w0 & PSMX_AM_EOM;
	has_data = args[0].u32w0 & PSMX_AM_DATA;

	switch (cmd) {
	case PSMX_AM_REQ_WRITE:
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		/* Only this chunk ([rma_addr, rma_addr + len)) is validated. */
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)(uintptr_t)rma_addr, len,
					 FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (!op_error) {
			rma_addr += mr->offset;
			memcpy(rma_addr, src, len);
			if (eom) {
				if (mr->domain->rma_ep->recv_cq && has_data) {
					/* TODO: report the addr/len of the whole write */
					event = psmx_cq_create_event(
							mr->domain->rma_ep->recv_cq,
							0, /* context */
							rma_addr,
							FI_REMOTE_WRITE | FI_RMA | FI_REMOTE_CQ_DATA,
							rma_len,
							args[4].u64,
							0, /* tag */
							0, /* olen */
							0);
					if (event)
						psmx_cq_enqueue_event(mr->domain->rma_ep->recv_cq, event);
					else
						err = -FI_ENOMEM;
				}
				if (mr->domain->rma_ep->caps & FI_RMA_EVENT) {
					if (mr->domain->rma_ep->remote_write_cntr)
						psmx_cntr_inc(mr->domain->rma_ep->remote_write_cntr);
					/* Avoid double-counting when the MR is bound to the
					 * same counter as the endpoint. */
					if (mr->cntr && mr->cntr != mr->domain->rma_ep->remote_write_cntr)
						psmx_cntr_inc(mr->cntr);
				}
			}
		}
		/* Intermediate chunks are not acknowledged unless they failed. */
		if (eom || op_error) {
			rep_args[0].u32w0 = PSMX_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			rep_err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
					rep_args, 2, NULL, 0, 0,
					NULL, NULL);
			/* Do not let the reply status hide an earlier -FI_ENOMEM. */
			if (!err)
				err = rep_err;
		}
		break;

	case PSMX_AM_REQ_WRITE_LONG:
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)(uintptr_t)rma_addr, rma_len,
					 FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
					rep_args, 2, NULL, 0, 0,
					NULL, NULL);
			break;
		}
		rma_addr += mr->offset;
		req = calloc(1, sizeof(*req));
		if (!req) {
			/* NOTE(review): no reply is sent here, so the initiator is
			 * not told about this failure -- confirm intended. */
			err = -FI_ENOMEM;
		} else {
			/* Remote CQ data travels in the payload for long writes;
			 * copy it out to avoid an unaligned/type-punned load. */
			cq_data = 0;
			if (has_data)
				memcpy(&cq_data, src, sizeof(cq_data));

			req->op = args[0].u32w0;
			req->write.addr = (uint64_t)(uintptr_t)rma_addr;
			req->write.len = rma_len;
			req->write.key = key;
			req->write.context = (void *)(uintptr_t)args[4].u64;
			req->write.peer_context = (void *)(uintptr_t)args[1].u64;
			req->write.peer_addr = (void *)epaddr;
			req->write.data = cq_data;
			req->cq_flags = FI_REMOTE_WRITE | FI_RMA |
					(has_data ? FI_REMOTE_CQ_DATA : 0);
			PSMX_CTXT_TYPE(&req->fi_context) = PSMX_REMOTE_WRITE_CONTEXT;
			PSMX_CTXT_USER(&req->fi_context) = mr;
			psmx_am_enqueue_rma(mr->domain, req);
		}
		break;

	case PSMX_AM_REQ_READ:
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		offset = args[4].u64;
		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)(uintptr_t)rma_addr, rma_len,
					 FI_REMOTE_READ) :
			-FI_EINVAL;
		if (!op_error) {
			rma_addr += mr->offset;
		} else {
			/* Error reply carries no payload. */
			rma_addr = NULL;
			rma_len = 0;
		}
		rep_args[0].u32w0 = PSMX_AM_REP_READ | eom;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		rep_args[2].u64 = offset;
		err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
				rep_args, 3, rma_addr, rma_len, 0,
				NULL, NULL);
		if (eom && !op_error) {
			if (mr->domain->rma_ep->caps & FI_RMA_EVENT) {
				if (mr->domain->rma_ep->remote_read_cntr)
					psmx_cntr_inc(mr->domain->rma_ep->remote_read_cntr);
			}
		}
		break;

	case PSMX_AM_REQ_READ_LONG:
		rma_len = args[0].u32w1;
		rma_addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)(uintptr_t)rma_addr, rma_len,
					 FI_REMOTE_READ) :
			-FI_EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX_AM_REP_READ | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			rep_args[2].u64 = 0;
			err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
					rep_args, 3, NULL, 0, 0,
					NULL, NULL);
			break;
		}
		rma_addr += mr->offset;
		req = calloc(1, sizeof(*req));
		if (!req) {
			/* NOTE(review): no reply is sent here, so the initiator is
			 * not told about this failure -- confirm intended. */
			err = -FI_ENOMEM;
		} else {
			req->op = args[0].u32w0;
			req->read.addr = (uint64_t)(uintptr_t)rma_addr;
			req->read.len = rma_len;
			req->read.key = key;
			req->read.context = (void *)(uintptr_t)args[4].u64;
			req->read.peer_addr = (void *)epaddr;
			PSMX_CTXT_TYPE(&req->fi_context) = PSMX_REMOTE_READ_CONTEXT;
			PSMX_CTXT_USER(&req->fi_context) = mr;
			psmx_am_enqueue_rma(mr->domain, req);
		}
		break;

	case PSMX_AM_REP_WRITE:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX_AM_REQ_WRITE);
		op_error = (int)args[0].u32w1;
		/* Keep the first error reported for this request. */
		if (!req->error)
			req->error = op_error;
		if (eom) {
			/* Errors are always reported, even if events were suppressed. */
			if (req->ep->send_cq && (!req->no_event || req->error)) {
				event = psmx_cq_create_event(
						req->ep->send_cq,
						req->write.context,
						req->write.buf,
						req->cq_flags,
						req->write.len,
						0, /* data */
						0, /* tag */
						0, /* olen */
						req->error);
				if (event)
					psmx_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -FI_ENOMEM;
			}
			if (req->ep->write_cntr)
				psmx_cntr_inc(req->ep->write_cntr);
			free(req);
		}
		break;

	case PSMX_AM_REP_READ:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX_AM_REQ_READ);
		op_error = (int)args[0].u32w1;
		offset = args[2].u64;
		/* offset/len come from the peer: refuse data that would land
		 * outside the local read buffer. */
		if (!op_error &&
		    (offset > req->read.len || len > req->read.len - offset))
			op_error = -FI_EINVAL;
		/* Keep the first error reported for this request. */
		if (!req->error)
			req->error = op_error;
		if (!op_error) {
			memcpy((uint8_t *)req->read.buf + offset, src, len);
			req->read.len_read += len;
		}
		if (eom) {
			if (req->ep->send_cq && (!req->no_event || req->error)) {
				event = psmx_cq_create_event(
						req->ep->send_cq,
						req->read.context,
						req->read.buf,
						req->cq_flags,
						req->read.len_read,
						0, /* data */
						0, /* tag */
						req->read.len - req->read.len_read, /* olen */
						req->error);
				if (event)
					psmx_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -FI_ENOMEM;
			}
			if (req->ep->read_cntr)
				psmx_cntr_inc(req->ep->read_cntr);
			free(req);
		}
		break;

	default:
		err = -FI_EINVAL;
	}

	return err;
}