c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::broadcast()

in src/torch_ucc.cpp [1055:1088]


// Broadcasts tensors[0] from rank `opts.rootRank` to the other ranks of the
// group by posting a UCC_COLL_TYPE_BCAST collective.
//
// @param tensors  Tensor list; only tensors[0] is used here. The list is
//                 validated by check_tensor() first (presumably this enforces
//                 a single suitable tensor — confirm in check_tensor).
// @param opts     Broadcast options; only `rootRank` is read.
// @return         A work handle. For a single-rank group no collective is
//                 posted and a WorkUCC for OpType::BROADCAST is returned
//                 directly; otherwise the handle comes from collective_post().
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  if (size_ == 1) {
    // Single rank: nothing to exchange, so validation and UCC are skipped.
    // The profiling title is only attached when profiling is enabled.
    return c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
        OpType::BROADCAST,
        torch_ucc_config.enable_profiling ? "ucc:broadcast" : nullptr);
  }
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  // Owned by a unique_ptr from the start so it is released if anything below
  // throws (e.g. ucc_dtype_map.at() for an unsupported dtype) before
  // ownership is handed over to collective_post().
  auto data = std::make_unique<WorkData>();

  // Value-initialize so fields not set below are zeroed instead of being
  // left indeterminate.
  ucc_coll_args_t coll{};
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_BCAST;
  // Only `src` is populated; presumably UCC bcast uses this buffer on every
  // rank (sent from the root, received into on the others) — confirm against
  // the UCC API docs.
  coll.src.info.buffer = tensor.data_ptr();
  coll.src.info.count = tensor.numel();
  coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
  coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
  coll.root = opts.rootRank;
  // Store the tensors in the work data (presumably to keep them referenced
  // while the collective is in flight — verify SAVE_TENSORS).
  SAVE_TENSORS(tensors, data->dst);

  // Both callbacks passed to collective_post are no-ops for broadcast.
  return collective_post(
      OpType::BROADCAST,
      []() {},
      []() {},
      coll,
      std::move(data),
      tensor.device(),
      tensors,
      "ucc:broadcast");
}