c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::reduce_scatter()

in src/torch_ucc.cpp [1103:1173]
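This entry point checks that the output list and the per-device input lists are consistent, flattens each input list into a contiguous staging tensor via c10d::newLikeFlat, fills a ucc_coll_args_t describing a UCC_COLL_TYPE_REDUCE_SCATTER collective, and hands the work off to collective_post. The copy of the input chunks into the flat buffer is deferred to the copy_to_flat callback so it runs only when the collective is actually posted; for CUDA tensors the copy is issued asynchronously and the input storages are registered with the caching allocator's stream tracking.
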


c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(
      (outputTensors.size() == inputTensors.size()),
      "Tensor input/output list for reduce_scatter must have same size");
  check_tensor(outputTensors);
  check_device(inputTensors[0][0].device(), outputTensors[0].device());
  initComm(inputTensors[0][0].device());
  auto data = std::make_unique<WorkData>();
  // One contiguous staging tensor per output; the UCC collective takes a
  // single flat source buffer rather than a list of per-rank chunks.
  std::vector<at::Tensor> flat_input(inputTensors.size());
  for (size_t i = 0; i < inputTensors.size(); i++) {
    TORCH_CHECK(
        inputTensors[i].size() == inputTensors.size() * size_,
        "Tensor input list is not valid for the number of participants");
    flat_input[i] = c10d::newLikeFlat(inputTensors, i);
  }
  SAVE_TENSORS(flat_input, data->flat);
  check_tensor(flat_input);
  ucc_coll_args_t coll;
  coll.mask = 0;
  coll.flags = 0;
  coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
  coll.op = ucc_op_map.at(opts.reduceOp);

  // Source is the flattened staging buffer; destination is this rank's
  // output tensor.
  coll.src.info.buffer = flat_input[0].data_ptr();
  coll.src.info.count = flat_input[0].numel();
  coll.src.info.datatype = ucc_dtype_map.at(flat_input[0].scalar_type());
  coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type());
  coll.dst.info.buffer = outputTensors[0].data_ptr();
  coll.dst.info.count = outputTensors[0].numel();
  coll.dst.info.datatype = ucc_dtype_map.at(outputTensors[0].scalar_type());
  coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());

  // Keep references to the operand tensors so they stay alive until the
  // asynchronous collective completes.
  SAVE_TENSORS(inputTensors[0], data->src);
  SAVE_TENSORS(outputTensors, data->dst);

  // Deferred until collective_post: stage every input chunk into the flat
  // buffer, checking that each chunk matches its output's element count.
  auto copy_to_flat = [&] {
    bool asyncCopy = false;
    auto isize = inputTensors.size();
#ifdef USE_CUDA
    bool isCuda = inputTensors[0][0].device().is_cuda();
#endif
    for (size_t i = 0; i < isize; i++) {
      auto onumel = outputTensors[i].numel();
      for (size_t j = 0; j < inputTensors[i].size(); j++) {
        TORCH_CHECK(
            (inputTensors[i][j].numel() == onumel),
            "Tensor operand counts must be same");
#ifdef USE_CUDA
        if (isCuda) {
          // Record the storage on the UCC stream so the caching allocator
          // does not recycle it early, then issue the copy asynchronously.
          c10::cuda::CUDACachingAllocator::recordStream(
              inputTensors[i][j].storage().data_ptr(), (*stream));
          asyncCopy = true;
        }
#endif
        flat_input[i][j].copy_(inputTensors[i][j], asyncCopy);
      }
    }
  };

  return collective_post(
      OpType::REDUCE_SCATTER,
      copy_to_flat,
      []() {},
      coll,
      std::move(data),
      inputTensors[0][0].device(),
      outputTensors,
      "ucc:reduce_scatter");
}
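
For orientation, here is a minimal caller-side sketch of how this method is reached through the generic c10d::ProcessGroup interface. The helper run_reduce_scatter is hypothetical, the include paths differ between PyTorch releases, and the process group is assumed to already be constructed on top of this UCC backend with a single device per process:

#include <c10d/ProcessGroup.hpp>  // path varies across PyTorch versions
#include <c10d/Types.hpp>
#include <torch/torch.h>

// Hypothetical helper: sum-reduce one chunk per rank and return the chunk
// that belongs to the calling rank.
at::Tensor run_reduce_scatter(c10d::ProcessGroup& pg, int64_t chunk_numel) {
  const int world_size = pg.getSize();

  // One inner list per local device (a single device here); it must hold
  // world_size chunks to satisfy the size check in reduce_scatter() above.
  std::vector<std::vector<at::Tensor>> inputTensors(1);
  for (int r = 0; r < world_size; ++r) {
    inputTensors[0].push_back(torch::ones({chunk_numel}));
  }
  std::vector<at::Tensor> outputTensors{torch::empty({chunk_numel})};

  c10d::ReduceScatterOptions opts;
  opts.reduceOp = c10d::ReduceOp::SUM;

  auto work = pg.reduce_scatter(outputTensors, inputTensors, opts);
  work->wait();  // block until the UCC collective has completed

  // With all-ones inputs and SUM, every element of the result equals
  // world_size.
  return outputTensors[0];
}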