c10::intrusive_ptr ProcessGroupUCC::allreduce()

in src/torch_ucc.cpp [875:911]


// Posts an in-place allreduce over tensors[0] using the reduction operation
// given in opts.reduceOp.
//
// tensors: validated by check_tensor(); only the first element is used as the
//          communication buffer, and the result is written back into it.
// opts:    supplies the reduce operation, translated through ucc_op_map.
//
// Returns the work handle produced by collective_post(), or a freshly
// constructed WorkUCC when size_ == 1 (no collective is posted in that case).
//
// ucc_op_map.at() / ucc_dtype_map.at() throw std::out_of_range when the
// reduce op or the tensor's scalar type has no entry in the respective map.
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  if (size_ == 1) {
    // Nothing is posted here; hand back a work object without touching the
    // tensors. NOTE(review): presumably size_ is the number of ranks in the
    // group, making this the single-process shortcut - confirm.
    return c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
        OpType::ALLREDUCE,
        torch_ucc_config.enable_profiling ? "ucc:allreduce" : nullptr);
  }
  check_tensor(tensors);
  auto& tensor = tensors[0];
  initComm(tensor.device());

  // Owned by a unique_ptr from the start so that it is released if any of the
  // lookups below throws before ownership is handed to collective_post().
  auto data = std::make_unique<WorkData>();

  // Value-initialize so that fields not assigned below start zeroed rather
  // than holding indeterminate values.
  ucc_coll_args_t coll{};
  coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
  coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
  coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
  coll.op = ucc_op_map.at(opts.reduceOp);

  // Source and destination describe the same element count, datatype and
  // memory type; compute each once.
  const auto count = tensor.numel();
  const auto dtype = ucc_dtype_map.at(tensor.scalar_type());
  const auto mem_type = to_ucc_memType(tensor.device().type());

  // The in-place flag is set above, so no separate source buffer is supplied;
  // the tensor's storage is passed as the destination only.
  coll.src.info.buffer = nullptr;
  coll.src.info.count = count;
  coll.src.info.datatype = dtype;
  coll.src.info.mem_type = mem_type;
  coll.dst.info.buffer = tensor.data_ptr();
  coll.dst.info.count = count;
  coll.dst.info.datatype = dtype;
  coll.dst.info.mem_type = mem_type;

  SAVE_TENSORS(tensors, data->dst);
  return collective_post(
      OpType::ALLREDUCE,
      []() {}, // first callback: no-op for allreduce
      []() {}, // second callback: no-op for allreduce
      coll,
      std::move(data),
      tensor.device(),
      tensors,
      "ucc:allreduce");
}