in src/torch_ucc.cpp [875:911]
/**
 * Posts an in-place UCC allreduce on the first tensor of `tensors`.
 *
 * When the group has a single rank (size_ == 1) no collective is posted and a
 * bare WorkUCC object is returned immediately; note that this path returns
 * before check_tensor() is called.
 *
 * @param tensors  Tensor list; only tensors[0] is described to UCC. The list
 *                 is validated by check_tensor() (exact checks are defined
 *                 elsewhere) and forwarded to collective_post().
 * @param opts     Allreduce options; opts.reduceOp is translated through
 *                 ucc_op_map.
 * @return Work handle produced by collective_post(), or a standalone WorkUCC
 *         for the single-rank case.
 *
 * NOTE(review): ucc_op_map.at() / ucc_dtype_map.at() presumably throw for
 * unsupported reduce ops / dtypes (as standard-library map `at` does) --
 * confirm against the map definitions.
 */
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  if (size_ == 1) {
    return c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
        OpType::ALLREDUCE,
        torch_ucc_config.enable_profiling ? "ucc:allreduce" : nullptr);
  }
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  // Own the work data from the moment it is allocated so that it is released
  // if any of the lookups or SAVE_TENSORS below throws before ownership is
  // handed to collective_post().
  auto data = std::make_unique<WorkData>();

  // Value-initialize so that fields not selected by `mask` are zero rather
  // than indeterminate.
  ucc_coll_args_t coll{};
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
  coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
  coll.op = ucc_op_map.at(opts.reduceOp);

  // Source and destination describe the same tensor; compute the shared
  // attributes once.
  const auto count = tensor.numel();
  const auto datatype = ucc_dtype_map.at(tensor.scalar_type());
  const auto mem_type = to_ucc_memType(tensor.device().type());

  // The operation is flagged in-place, so no separate source buffer is given;
  // the tensor's storage is supplied as the destination only.
  coll.src.info.buffer = nullptr;
  coll.src.info.count = count;
  coll.src.info.datatype = datatype;
  coll.src.info.mem_type = mem_type;
  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = count;
  coll.dst.info.datatype = datatype;
  coll.dst.info.mem_type = mem_type;

  SAVE_TENSORS(tensors, data->dst);

  // The two empty lambdas are no-op callbacks; their roles are defined by
  // collective_post(), which is not shown here.
  return collective_post(
      OpType::ALLREDUCE,
      []() {},
      []() {},
      coll,
      std::move(data),
      tensor.device(),
      tensors,
      "ucc:allreduce");
}