in src/torch_ucc.cpp [416:444]
// Posts a non-blocking tagged receive on this process group's UCX worker
// (ucx_comm.worker) into `data`.
//
// @param data         destination buffer for the incoming message
// @param mtype        UCS memory type of `data`, forwarded to UCP as
//                     params.memory_type
// @param size         buffer size in bytes; described to UCP as ONE
//                     contiguous element of `size` bytes (count == 1)
// @param ucp_tag      tag the incoming message must match
// @param ucp_tag_mask mask applied when matching against `ucp_tag`
// @return the pointer returned by ucp_tag_recv_nbx, reinterpreted as a
//         ucc_coll_req_h.
//         NOTE(review): per the UCP docs ucp_tag_recv_nbx may return NULL
//         (UCS_OK) when the receive completes immediately, in which case
//         this returns nullptr — confirm all callers handle that.
// @throws std::runtime_error if ucp_tag_recv_nbx reports an error; the
//         failure is also logged under TORCH_UCC_COLL_POST.
ucc_coll_req_h CommPG::recv_nb(
    void* data,
    ucs_memory_type_t mtype,
    size_t size,
    ucp_tag_t ucp_tag,
    ucp_tag_t ucp_tag_mask) {
  // Value-initialize so fields not selected by op_attr_mask are zero rather
  // than indeterminate stack contents.
  ucp_request_param_t params{};
  params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
      UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
  params.datatype = ucp_dt_make_contig(size);
  // Completion callback. The UCP request memory is treated as a
  // ucc_coll_req (presumably the worker's request size is configured for
  // that elsewhere — confirm), and completion is signalled via its status.
  // NOTE(review): the UCS completion status is ignored and the request is
  // always marked UCC_OK, so a failed/canceled/truncated receive is
  // indistinguishable from success. Consider propagating errors once the
  // callers' status polling has been checked.
  params.cb.recv = [](void* request,
                      ucs_status_t /* status */,
                      const ucp_tag_recv_info_t* /* info */,
                      void* /* user_data */) {
    static_cast<ucc_coll_req_h>(request)->status = UCC_OK;
  };
  params.memory_type = mtype;

  ucs_status_ptr_t st = ucp_tag_recv_nbx(
      ucx_comm.worker, data, 1, ucp_tag, ucp_tag_mask, &params);
  if (UCS_PTR_IS_ERR(st)) {
    TORCH_UCC_LOG_ERROR(
        TORCH_UCC_COLL_POST,
        c10::str(
            "failed to recv message: ", ucs_status_string(UCS_PTR_STATUS(st))));
    throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st)));
  }
  return reinterpret_cast<ucc_coll_req_h>(st);
}