in src/torch_ucc.cpp [1055:1088]
// Broadcasts tensors[0] from rank opts.rootRank to every rank of the group
// via UCC's UCC_COLL_TYPE_BCAST collective.
//
// A single-rank group has nothing to exchange, so a WorkUCC is returned
// immediately without posting any UCC collective. Otherwise the tensor's
// buffer is described as the collective's `src`, and the same `tensors` are
// saved into WorkData::dst and passed to collective_post as the outputs.
// NOTE(review): this relies on UCC bcast operating in place on `src` (root
// sends it, non-root ranks receive into it) — confirm against the UCC spec.
// NOTE(review): opts.rootTensor is never read; presumably check_tensor()
// restricts `tensors` to a single tensor so it can only be 0 — confirm.
//
// @param tensors  Tensor list to broadcast; only tensors[0] is used here.
// @param opts     Broadcast options; only opts.rootRank is read.
// @return         Work handle for the (possibly already complete) broadcast.
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  if (size_ == 1) {
    return c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
        OpType::BROADCAST,
        torch_ucc_config.enable_profiling ? "ucc:broadcast" : nullptr);
  }
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  // Own the WorkData from the moment it is allocated. The calls below can
  // throw before ownership reaches collective_post (e.g. ucc_dtype_map.at for
  // a scalar type with no UCC mapping, assuming it is a standard map), and a
  // raw `new` would leak in that case.
  auto data = std::make_unique<WorkData>();

  // Value-initialize so that fields not set below (e.g. dst) are zero rather
  // than indeterminate stack contents.
  ucc_coll_args_t coll{};
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = opts.rootRank;
  SAVE_TENSORS(tensors, data->dst);
  return collective_post(
      OpType::BROADCAST,
      []() {},
      []() {},
      coll,
      std::move(data),
      tensor.device(),
      tensors,
      "ucc:broadcast");
}