// in src/torch_ucc.cpp [1103:1173]
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  // Reduce-scatter collective: element-wise reduction (opts.reduceOp) over
  // every participant's contributions, with each rank keeping one shard of
  // the result. Inputs are staged into a contiguous flat buffer before the
  // UCC collective is posted, since UCC works on flat src/dst buffers.
  TORCH_CHECK(
      (outputTensors.size() == inputTensors.size()),
      "Tensor input/output list for reduce_scatter must have same size");
  check_tensor(outputTensors);
  check_device(inputTensors[0][0].device(), outputTensors[0].device());
  initComm(inputTensors[0][0].device());

  auto work_data = std::make_unique<WorkData>();

  // One flattened staging tensor per inner input list; each inner list must
  // hold (number of lists) * (world size) contributions, per the check below.
  const size_t n_lists = inputTensors.size();
  std::vector<at::Tensor> flattened(n_lists);
  for (size_t idx = 0; idx < n_lists; idx++) {
    TORCH_CHECK(
        inputTensors[idx].size() == n_lists * size_,
        "Tensor input list is not valid for the number of participants");
    flattened[idx] = c10d::newLikeFlat(inputTensors, idx);
  }
  SAVE_TENSORS(flattened, work_data->flat);
  check_tensor(flattened);

  // Describe the collective. NOTE: only index 0 of the flat inputs / outputs
  // is wired into the UCC descriptor — a single buffer pair is posted.
  ucc_coll_args_t coll_args;
  coll_args.mask = 0;
  coll_args.flags = 0;
  coll_args.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
  coll_args.op = ucc_op_map.at(opts.reduceOp);
  coll_args.src.info.buffer = flattened[0].data_ptr();
  coll_args.src.info.count = flattened[0].numel();
  coll_args.src.info.datatype = ucc_dtype_map.at(flattened[0].scalar_type());
  coll_args.src.info.mem_type = to_ucc_memType(flattened[0].device().type());
  coll_args.dst.info.buffer = outputTensors[0].data_ptr();
  coll_args.dst.info.count = outputTensors[0].numel();
  coll_args.dst.info.datatype = ucc_dtype_map.at(outputTensors[0].scalar_type());
  coll_args.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());

  // Hold references so the source/destination storage outlives the op.
  SAVE_TENSORS(inputTensors[0], work_data->src);
  SAVE_TENSORS(outputTensors, work_data->dst);

  // Pre-processing hook: copy each input tensor into its slot of the flat
  // staging buffer. Captures locals by reference — presumably collective_post
  // invokes this before returning; verify if that contract ever changes.
  auto stage_inputs = [&] {
    bool copy_async = false;
    const size_t outer = inputTensors.size();
#ifdef USE_CUDA
    const bool on_cuda = inputTensors[0][0].device().is_cuda();
#endif
    for (size_t i = 0; i < outer; i++) {
      const auto expected_numel = outputTensors[i].numel();
      for (size_t j = 0; j < inputTensors[i].size(); j++) {
        TORCH_CHECK(
            (inputTensors[i][j].numel() == expected_numel),
            "Tensor operand counts must be same");
#ifdef USE_CUDA
        if (on_cuda) {
          // Pin the tensor's storage to the comm stream so the caching
          // allocator does not recycle it while the async copy is in flight.
          c10::cuda::CUDACachingAllocator::recordStream(
              inputTensors[i][j].storage().data_ptr(), (*stream));
          copy_async = true;
        }
#endif
        flattened[i][j].copy_(inputTensors[i][j], copy_async);
      }
    }
  };

  return collective_post(
      OpType::REDUCE_SCATTER,
      stage_inputs,
      []() {},
      coll_args,
      std::move(work_data),
      inputTensors[0][0].device(),
      outputTensors,
      "ucc:reduce_scatter");
}