in src/torch_ucc.cpp [764:866]
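// ProcessGroupUCC::allgather: every rank contributes inputTensors[0] and
// receives each rank's contribution into the tensors of outputTensors[0].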
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& /* unused */) {
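  // Single-rank group: nothing to exchange, a local copy of the input is enough.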
if (size_ == 1) {
outputTensors[0][0].copy_(inputTensors[0]);
return c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
OpType::ALLGATHER,
torch_ucc_config.enable_profiling ? "ucc:allgather" : nullptr);
}
auto& tensor = inputTensors[0];
check_device(tensor.device(), outputTensors[0][0].device());
initComm(tensor.device());
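  // CPU tensors (or an explicit opt-in via use_allgatherv) take the allgatherv
  // path, which lets UCC deposit each rank's data directly into the per-rank
  // output tensors; otherwise the result is gathered into a flat staging
  // buffer and unpacked afterwards.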
if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
AllgathervWorkData* data = new AllgathervWorkData(size_);
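    // Every rank contributes the same number of bytes; recv_offsets stores the
    // raw destination pointer of each per-rank output tensor.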
for (int i = 0; i < size_; i++) {
data->recv_lengths[i] = tensor.element_size() * tensor.numel();
data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
}
ucc_coll_args_t coll;
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags =
UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.element_size() * tensor.numel();
coll.src.info.datatype = UCC_DT_UINT8;
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
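    // No packed destination buffer: the 64-bit displacements filled in above
    // already hold absolute destination addresses, so each contribution lands
    // directly in the corresponding output tensor.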
coll.dst.info_v.buffer = nullptr;
coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
coll.dst.info_v.datatype = UCC_DT_UINT8;
coll.dst.info_v.mem_type =
to_ucc_memType(outputTensors[0][0].device().type());
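    // Stash the source and destination tensors in the work data so they stay
    // referenced while the collective is in flight.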
SAVE_TENSORS(inputTensors, data->src);
SAVE_TENSORS(outputTensors[0], data->dst);
return collective_post(
OpType::ALLGATHER,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
outputTensors[0],
"ucc:allgatherv");
} else {
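    // Dense allgather path: gather into one flat buffer shaped like the output
    // list, then unpack into the individual output tensors afterwards.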
WorkData* data = new WorkData();
std::vector<at::Tensor> flat_output(outputTensors.size());
for (size_t i = 0; i < outputTensors.size(); i++) {
      TORCH_CHECK(
          outputTensors[i].size() == static_cast<size_t>(size_),
          "Tensor output list is not valid for the number of participants");
flat_output[i] = c10d::newLikeFlat(outputTensors, i);
}
SAVE_TENSORS(flat_output, data->flat);
ucc_coll_args_t coll;
coll.mask = 0;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.numel();
coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.dst.info.buffer = flat_output[0].data_ptr();
coll.dst.info.count = flat_output[0].numel();
coll.dst.info.datatype = ucc_dtype_map.at(flat_output[0].scalar_type());
coll.dst.info.mem_type =
to_ucc_memType(outputTensors[0][0].device().type());
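    // Post-processing hook: checks element counts and copies every slice of the
    // flat result back into the matching user-provided output tensor.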
auto copy_from_flat = [&] {
bool asyncCopy = false;
#ifdef USE_CUDA
      bool isCuda = outputTensors[0][0].device().is_cuda();
#endif
for (size_t i = 0; i < outputTensors.size(); i++) {
auto inumel = inputTensors[i].numel();
for (size_t j = 0; j < outputTensors[i].size(); j++) {
TORCH_CHECK(
(outputTensors[i][j].numel() == inumel),
"Tensor operand counts must be same");
#ifdef USE_CUDA
if (isCuda) {
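            // Prevent the caching allocator from reusing the output storage
            // while the asynchronous copy on the UCC stream is still pending.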
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors[i][j].storage().data_ptr(), (*stream));
asyncCopy = true;
}
#endif
outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
}
}
};
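    // copy_from_flat is handed to collective_post as the post-processing step
    // that runs once the allgather into the flat buffer has completed.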
return collective_post(
OpType::ALLGATHER,
[]() {},
copy_from_flat,
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
outputTensors[0],
"ucc:allgather");
}
}
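
A minimal standalone sketch of the bookkeeping used by the allgatherv branch above, assuming plain host memory and std::memcpy in place of the UCC collective (this toy program is illustrative only and not part of torch_ucc.cpp): the lengths are byte counts and the "displacements" are absolute destination pointers, so no packed receive buffer is needed.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const int world_size = 4;      // stand-in for size_
  const int elems_per_rank = 3;  // stand-in for tensor.numel()

  // One separate output buffer per rank, as in outputTensors[0][i].
  std::vector<std::vector<int32_t>> outputs(
      world_size, std::vector<int32_t>(elems_per_rank));

  // Same bookkeeping as AllgathervWorkData: byte lengths plus raw addresses.
  std::vector<uint64_t> recv_lengths(world_size);
  std::vector<uint64_t> recv_offsets(world_size);
  for (int i = 0; i < world_size; i++) {
    recv_lengths[i] = sizeof(int32_t) * elems_per_rank;
    recv_offsets[i] = (uint64_t)outputs[i].data();
  }

  // Stand-in for the collective: rank i's payload lands at recv_offsets[i].
  for (int i = 0; i < world_size; i++) {
    std::vector<int32_t> contribution(elems_per_rank, i);
    std::memcpy((void*)recv_offsets[i], contribution.data(), recv_lengths[i]);
  }

  for (int i = 0; i < world_size; i++) {
    std::cout << "from rank " << i << ":";
    for (int32_t v : outputs[i]) {
      std::cout << ' ' << v;
    }
    std::cout << '\n';
  }
  return 0;
}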