in src/torch_ucc.cpp [390:414]
// Posts a tagged send of `size` bytes starting at `data` to endpoint `ep`
// through ucp_tag_send_nbx and returns the resulting request handle.
//
// Parameters:
//   ep      - UCP endpoint the message is sent to.
//   data    - start of the send buffer.
//   mtype   - memory type of `data`, forwarded to UCP via params.memory_type.
//   size    - message size in bytes; the message is posted as one contiguous
//             element of this size (count == 1).
//   ucp_tag - tag attached to the message.
//
// Returns the status pointer produced by ucp_tag_send_nbx, reinterpreted as a
// ucc_coll_req_h. The completion callback installed here writes UCC_OK into
// that request's `status` field.
// NOTE(review): per the UCX documentation ucp_tag_send_nbx may return NULL
// when the send completes immediately, in which case this returns a null
// handle -- confirm that callers handle it.
//
// Throws std::runtime_error (after logging) when UCP returns an error pointer.
ucc_coll_req_h CommPG::send_nb(
    ucp_ep_h ep,
    void* data,
    ucs_memory_type_t mtype,
    size_t size,
    ucp_tag_t ucp_tag) {
  ucp_request_param_t params;
  // Only the fields flagged in op_attr_mask are filled in below.
  params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
      UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
  params.datatype = ucp_dt_make_contig(size);
  params.memory_type = mtype;
  // NOTE(review): the completion status passed to the callback is ignored and
  // the request is always marked UCC_OK -- confirm this is intended.
  params.cb.send =
      [](void* request, ucs_status_t /*status*/, void* /*user_data*/) {
        static_cast<ucc_coll_req_h>(request)->status = UCC_OK;
      };

  // The datatype already encodes the byte count, so exactly one element is
  // sent. The request parameters must be passed by address.
  ucs_status_ptr_t st = ucp_tag_send_nbx(ep, data, 1, ucp_tag, &params);
  if (UCS_PTR_IS_ERR(st)) {
    const char* reason = ucs_status_string(UCS_PTR_STATUS(st));
    TORCH_UCC_LOG_ERROR(
        TORCH_UCC_COLL_POST, c10::str("failed to send message: ", reason));
    throw std::runtime_error(reason);
  }
  return reinterpret_cast<ucc_coll_req_h>(st);
}