in prov/sockets/src/sock_atomic.c [61:250]
/*
 * Post an atomic operation (plain, fetch or compare) by serializing it into
 * the TX context's ring buffer.
 *
 * Record written to the ring buffer, in order:
 *   op_send header                       (sock_tx_ctx_write_op_send)
 *   uint64_t remote CQ data              only if FI_REMOTE_CQ_DATA
 *   if FI_INJECT: source data bytes, then compare data bytes
 *   otherwise:    one union sock_iov per source buffer, then per compare buffer
 *   one union sock_iov per target (rma) ioc
 *   one union sock_iov per result ioc
 *
 * @ep            endpoint (FI_CLASS_EP) or TX context (FI_CLASS_TX_CTX)
 * @msg           source iocs, target rma iocs, op, datatype, addr, context
 * @comparev      compare_count compare buffers
 * @compare_desc  not used by this function
 * @compare_count number of entries in comparev
 * @resultv       result_count result buffers
 * @result_desc   not used by this function
 * @result_count  number of entries in resultv
 * @flags         operation flags; when SOCK_USE_OP_FLAGS ends up set after
 *                SOCK_EP_SET_TX_OP_FLAGS, the endpoint/context op_flags are
 *                OR-ed in
 *
 * Returns 0 on success, -FI_EINVAL on invalid arguments, -FI_EOPBADSTATE if
 * the TX context is not enabled, -FI_EAGAIN if the ring buffer lacks space,
 * or the error returned by sock_ep_get_conn / sock_queue_atomic_op.
 */
ssize_t sock_ep_tx_atomic(struct fid_ep *ep,
			  const struct fi_msg_atomic *msg,
			  const struct fi_ioc *comparev, void **compare_desc,
			  size_t compare_count, struct fi_ioc *resultv,
			  void **result_desc, size_t result_count, uint64_t flags)
{
	ssize_t ret;
	size_t i;
	size_t datatype_sz;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	uint64_t total_len, src_len, dst_len, cmp_len, op_flags;
	struct sock_ep *sock_ep;
	struct sock_ep_attr *ep_attr;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->attr->tx_ctx->use_shared ?
			sock_ep->attr->tx_ctx->stx_ctx : sock_ep->attr->tx_ctx;
		ep_attr = sock_ep->attr;
		op_flags = sock_ep->tx_attr.op_flags;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		ep_attr = tx_ctx->ep_attr;
		op_flags = tx_ctx->attr.op_flags;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

	/*
	 * Compare and result iov counts are serialized alongside the source
	 * and target iov counts, so they are bounded by the same limit.
	 */
	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT ||
	    msg->rma_iov_count > SOCK_EP_MAX_IOV_LIMIT ||
	    compare_count > SOCK_EP_MAX_IOV_LIMIT ||
	    result_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;

	if (!tx_ctx->enabled)
		return -FI_EOPBADSTATE;

	ret = sock_ep_get_conn(ep_attr, tx_ctx, msg->addr, &conn);
	if (ret)
		return ret;

	SOCK_EP_SET_TX_OP_FLAGS(flags);
	if (flags & SOCK_USE_OP_FLAGS)
		flags |= op_flags;

	/* FI_ATOMIC_READ is always posted without FI_INJECT. */
	if (msg->op == FI_ATOMIC_READ)
		flags &= ~FI_INJECT;

	if (flags & FI_TRIGGER) {
		ret = sock_queue_atomic_op(ep, msg, comparev, compare_count,
					   resultv, result_count, flags,
					   FI_OP_ATOMIC);
		/*
		 * A return of 1 falls through and posts the operation now;
		 * any other value is handed back to the caller unchanged.
		 */
		if (ret != 1)
			return ret;
	}

	src_len = cmp_len = 0;
	datatype_sz = ofi_datatype_size(msg->datatype);
	for (i = 0; i < compare_count; i++)
		cmp_len += comparev[i].count * datatype_sz;

	/*
	 * Compute the exact number of ring-buffer bytes written below, so
	 * that the availability check covers every write.
	 */
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++)
			src_len += msg->msg_iov[i].count * datatype_sz;

		if ((src_len + cmp_len) > SOCK_EP_MAX_INJECT_SZ)
			return -FI_EINVAL;

		/* inline source bytes followed by inline compare bytes */
		total_len = src_len + cmp_len;
	} else {
		/* one descriptor per source buffer and per compare buffer */
		total_len = (msg->iov_count + compare_count) *
			    sizeof(union sock_iov);
	}

	total_len += sizeof(struct sock_op_send) +
		     msg->rma_iov_count * sizeof(union sock_iov) +
		     result_count * sizeof(union sock_iov);
	if (flags & FI_REMOTE_CQ_DATA)
		total_len += sizeof(uint64_t);

	sock_tx_ctx_start(tx_ctx);
	if (ofi_rbavail(&tx_ctx->rb) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	memset(&tx_op, 0, sizeof(tx_op));
	tx_op.op = SOCK_OP_ATOMIC;
	tx_op.dest_iov_len = msg->rma_iov_count;
	tx_op.atomic.op = msg->op;
	tx_op.atomic.datatype = msg->datatype;
	tx_op.atomic.res_iov_len = result_count;
	tx_op.atomic.cmp_iov_len = compare_count;

	if (flags & FI_INJECT) {
		/* For inject, these fields carry byte counts of the inline data. */
		tx_op.src_iov_len = src_len;
		tx_op.atomic.cmp_iov_len = cmp_len;
	} else {
		tx_op.src_iov_len = msg->iov_count;
	}

	/* Do not index msg_iov when there are no source iocs. */
	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, flags,
			(uintptr_t) msg->context, msg->addr,
			msg->iov_count ? (uintptr_t) msg->msg_iov[0].addr : 0,
			ep_attr, conn);

	if (flags & FI_REMOTE_CQ_DATA)
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(uint64_t));

	src_len = dst_len = 0;
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].addr,
					  msg->msg_iov[i].count * datatype_sz);
			src_len += msg->msg_iov[i].count * datatype_sz;
		}
		for (i = 0; i < compare_count; i++) {
			sock_tx_ctx_write(tx_ctx, comparev[i].addr,
					  comparev[i].count * datatype_sz);
			dst_len += comparev[i].count * datatype_sz;
		}
	} else {
		for (i = 0; i < msg->iov_count; i++) {
			tx_iov.ioc.addr = (uintptr_t) msg->msg_iov[i].addr;
			tx_iov.ioc.count = msg->msg_iov[i].count;
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
			src_len += tx_iov.ioc.count * datatype_sz;
		}
		for (i = 0; i < compare_count; i++) {
			tx_iov.ioc.addr = (uintptr_t) comparev[i].addr;
			tx_iov.ioc.count = comparev[i].count;
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
			dst_len += tx_iov.ioc.count * datatype_sz;
		}
	}

#if ENABLE_DEBUG
	/* Size/consistency checks compiled only into debug builds. */
	if ((src_len > SOCK_EP_MAX_ATOMIC_SZ) ||
	    (dst_len > SOCK_EP_MAX_ATOMIC_SZ)) {
		SOCK_LOG_ERROR("Max atomic operation size exceeded!\n");
		ret = -FI_EINVAL;
		goto err;
	} else if (compare_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	/* Target descriptors: total target bytes must match source bytes. */
	dst_len = 0;
	for (i = 0; i < msg->rma_iov_count; i++) {
		tx_iov.ioc.addr = msg->rma_iov[i].addr;
		tx_iov.ioc.key = msg->rma_iov[i].key;
		tx_iov.ioc.count = msg->rma_iov[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += tx_iov.ioc.count * datatype_sz;
	}

	if (msg->iov_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
	/* The target length is the reference for the result-length check. */
	src_len = dst_len;

	dst_len = 0;
	for (i = 0; i < result_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) resultv[i].addr;
		tx_iov.ioc.count = resultv[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += tx_iov.ioc.count * datatype_sz;
	}

#if ENABLE_DEBUG
	if (result_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	/* Discard everything written since sock_tx_ctx_start(). */
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}