in prov/psm/src/psmx_atomic.c [374:594]
/*
 * Active-message handler registered as PSMX_AM_ATOMIC_HANDLER.
 *
 * It serves both sides of the AM-emulated atomic protocol:
 *   - PSMX_AM_REQ_ATOMIC_* (target side): look up and validate the MR for
 *     the remote key, apply the atomic operation to local memory, bump the
 *     RMA-event counters when the domain's atomics endpoint has
 *     FI_RMA_EVENT, and send a reply carrying the status and, for fetching
 *     operations, the previous memory contents.
 *   - PSMX_AM_REP_ATOMIC_* (initiator side): copy back any fetched data,
 *     generate a send-CQ event when required, bump the local counter and
 *     free the originating request.
 *
 * Argument layout, as decoded below:
 *   args[0].u32w0  opcode (masked with PSMX_AM_OP_MASK)
 *   args[0].u32w1  element count (requests) / op_error (replies)
 *   args[1].u64    initiator's struct psmx_am_request pointer, echoed back
 *   args[2].u64    target address (requests only)
 *   args[3].u64    MR key (requests only)
 *   args[4].u32w0  datatype (requests only)
 *   args[4].u32w1  atomic op (requests only)
 *
 * src/len carry the payload: operand data for requests, fetched data for
 * READWRITE/COMPWRITE replies.
 *
 * Returns the psm_am_reply_short() result on the request path, 0 or
 * -FI_ENOMEM (CQ event allocation failure) on the reply path, and
 * -FI_EINVAL for an unknown opcode.
 */
int psmx_am_atomic_handler(psm_am_token_t token, psm_epaddr_t epaddr,
			   psm_amarg_t *args, int nargs, void *src,
			   uint32_t len)
{
	psm_amarg_t rep_args[8];
	int count;
	uint8_t *addr;
	uint64_t key;
	int datatype, op;
	int err = 0;
	int op_error = 0;
	struct psmx_am_request *req;
	struct psmx_cq_event *event;
	struct psmx_fid_mr *mr;
	struct psmx_fid_ep *target_ep;
	struct psmx_fid_cntr *cntr = NULL;
	struct psmx_fid_cntr *mr_cntr = NULL;
	void *tmp_buf;
	size_t copy_len;

	switch (args[0].u32w0 & PSMX_AM_OP_MASK) {
	case PSMX_AM_REQ_ATOMIC_WRITE:
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;
		assert(len == ofi_datatype_size(datatype) * count);

		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)(uintptr_t)addr, len,
					 FI_REMOTE_WRITE) :
			-FI_EINVAL;

		if (!op_error) {
			addr += mr->offset;
			psmx_atomic_do_write(addr, src, datatype, op, count);

			target_ep = mr->domain->atomics_ep;
			if (target_ep->caps & FI_RMA_EVENT) {
				cntr = target_ep->remote_write_cntr;
				mr_cntr = mr->cntr;
				if (cntr)
					psmx_cntr_inc(cntr);
				/* Avoid double-counting when the MR is bound to
				 * the same counter as the endpoint. */
				if (mr_cntr && mr_cntr != cntr)
					psmx_cntr_inc(mr_cntr);
			}
		}

		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_WRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, NULL, 0, 0,
					 NULL, NULL);
		break;

	case PSMX_AM_REQ_ATOMIC_READWRITE:
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;

		/* For FI_ATOMIC_READ the incoming len is not used as the data
		 * size (presumably no operand payload is sent); derive it from
		 * the element count so the fetched result is sized correctly. */
		if (op == FI_ATOMIC_READ)
			len = ofi_datatype_size(datatype) * count;
		assert(len == ofi_datatype_size(datatype) * count);

		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)(uintptr_t)addr, len,
					 FI_REMOTE_READ|FI_REMOTE_WRITE) :
			-FI_EINVAL;

		tmp_buf = NULL;
		if (!op_error) {
			addr += mr->offset;
			tmp_buf = malloc(len);
			if (!tmp_buf)
				op_error = -FI_ENOMEM;
		}

		/* Only perform the operation, and only report it through the
		 * RMA-event counters, when the result buffer was obtained. */
		if (!op_error) {
			psmx_atomic_do_readwrite(addr, src, tmp_buf,
						 datatype, op, count);

			target_ep = mr->domain->atomics_ep;
			if (target_ep->caps & FI_RMA_EVENT) {
				if (op == FI_ATOMIC_READ) {
					cntr = target_ep->remote_read_cntr;
				} else {
					cntr = target_ep->remote_write_cntr;
					mr_cntr = mr->cntr;
				}
				if (cntr)
					psmx_cntr_inc(cntr);
				if (mr_cntr && mr_cntr != cntr)
					psmx_cntr_inc(mr_cntr);
			}
		}

		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		/* tmp_buf is handed to psmx_am_atomic_completion as its context.
		 * NOTE(review): if psm_am_reply_short() fails, confirm whether
		 * the completion still runs; otherwise tmp_buf may leak. */
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, tmp_buf, (tmp_buf ? len : 0), 0,
					 psmx_am_atomic_completion, tmp_buf);
		break;

	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;

		/* The payload holds the operand followed by the compare value,
		 * each len bytes; the second half starts at src + len. */
		len /= 2;
		assert(len == ofi_datatype_size(datatype) * count);

		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)(uintptr_t)addr, len,
					 FI_REMOTE_READ|FI_REMOTE_WRITE) :
			-FI_EINVAL;

		tmp_buf = NULL;
		if (!op_error) {
			addr += mr->offset;
			tmp_buf = malloc(len);
			if (!tmp_buf)
				op_error = -FI_ENOMEM;
		}

		/* Only perform the operation, and only report it through the
		 * RMA-event counters, when the result buffer was obtained. */
		if (!op_error) {
			psmx_atomic_do_compwrite(addr, src, (uint8_t *)src + len,
						 tmp_buf, datatype, op, count);

			target_ep = mr->domain->atomics_ep;
			if (target_ep->caps & FI_RMA_EVENT) {
				cntr = target_ep->remote_write_cntr;
				mr_cntr = mr->cntr;
				if (cntr)
					psmx_cntr_inc(cntr);
				if (mr_cntr && mr_cntr != cntr)
					psmx_cntr_inc(mr_cntr);
			}
		}

		/* Reply with the matching COMPWRITE opcode; the reply side
		 * handles it together with READWRITE. */
		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_COMPWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		/* NOTE(review): same tmp_buf ownership caveat as READWRITE. */
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, tmp_buf, (tmp_buf ? len : 0), 0,
					 psmx_am_atomic_completion, tmp_buf);
		break;

	case PSMX_AM_REP_ATOMIC_WRITE:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(req->op == PSMX_AM_REQ_ATOMIC_WRITE);

		/* Errors are always reported, even when events are suppressed. */
		if (req->ep->send_cq && (!req->no_event || op_error)) {
			event = psmx_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx_cq_enqueue_event(req->ep->send_cq, event);
			else
				err = -FI_ENOMEM;
		}

		/* NOTE(review): the success counter is bumped even when
		 * op_error is set; confirm this is intended. */
		if (req->ep->write_cntr)
			psmx_cntr_inc(req->ep->write_cntr);

		free(req);
		break;

	case PSMX_AM_REP_ATOMIC_READWRITE:
	case PSMX_AM_REP_ATOMIC_COMPWRITE:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(op_error || req->atomic.len == len);

		if (!op_error) {
			/* Never write past the caller's result buffer, even
			 * when the assertion above is compiled out. */
			copy_len = len < req->atomic.len ? len : req->atomic.len;
			memcpy(req->atomic.result, src, copy_len);
		}

		/* Errors are always reported, even when events are suppressed. */
		if (req->ep->send_cq && (!req->no_event || op_error)) {
			event = psmx_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx_cq_enqueue_event(req->ep->send_cq, event);
			else
				err = -FI_ENOMEM;
		}

		/* Fetching atomics are accounted on the read counter. */
		if (req->ep->read_cntr)
			psmx_cntr_inc(req->ep->read_cntr);

		free(req);
		break;

	default:
		err = -FI_EINVAL;
	}

	return err;
}