in src/torch_ucc.cpp [927:1001]
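// UCC implementation of ProcessGroup::alltoall_base: every rank scatters a
// slice of a single flat input tensor to each peer and gathers the peers'
// slices into a single flat output tensor. Empty split-size vectors request
// an even split along dim 0; non-empty vectors give per-rank splits.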
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& /* unused */) {
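  // Trivial group: with one rank, all-to-all degenerates to a local copy, so
  // copy and return a work handle that carries no UCC request.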
if (size_ == 1) {
outputTensor.copy_(inputTensor);
return c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
OpType::ALLTOALL_BASE,
torch_ucc_config.enable_profiling ? "ucc:alltoall" : nullptr);
}
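  // Validate that the input and output tensors live on compatible devices,
  // then lazily create (or fetch) the UCC communicator for that device.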
check_device(inputTensor.device(), outputTensor.device());
initComm(inputTensor.device());
ucc_coll_args_t coll;
AlltoallWorkData* data;
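  // Two wire formats: empty split vectors select the fixed-count ALLTOALL
  // path; explicit splits select the variable-count ALLTOALLV path.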
  if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
data = new AlltoallWorkData(0);
TORCH_CHECK(
(outputTensor.size(0) % size_ == 0) &&
(inputTensor.size(0) % size_ == 0),
"Tensor's dim 0 does not divide equally across group size");
coll.mask = 0;
coll.coll_type = UCC_COLL_TYPE_ALLTOALL;
coll.src.info.buffer = inputTensor.data_ptr();
coll.src.info.count = inputTensor.element_size() * inputTensor.numel();
coll.src.info.datatype = UCC_DT_UINT8;
coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
coll.dst.info.buffer = outputTensor.data_ptr();
coll.dst.info.count = outputTensor.element_size() * outputTensor.numel();
coll.dst.info.datatype = UCC_DT_UINT8;
coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
coll.flags = 0;
} else {
data = new AlltoallWorkData(size_);
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
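    // Translate the user-visible split sizes (rows along dim 0) into per-rank
    // element counts and displacements. They are stored in the WorkData so
    // the arrays stay valid while UCC progresses the collective.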
computeLengthsAndOffsets(
outputSplitSizes,
outputTensor,
&data->recv_lengths,
&data->recv_offsets);
computeLengthsAndOffsets(
inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets);
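    // UCC_COLL_ARGS_FIELD_FLAGS marks the flags field below as valid.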
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
coll.src.info_v.buffer = inputTensor.data_ptr();
coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
coll.src.info_v.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type());
coll.dst.info_v.buffer = outputTensor.data_ptr();
coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
coll.dst.info_v.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type());
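    // The per-rank segments sit back-to-back in one contiguous buffer on each
    // side; these flags advertise that layout to UCC.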
coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER;
}
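  // Hold references to both tensors in the work data so their storage cannot
  // be freed (or reused by a caching allocator) while the collective is in
  // flight.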
std::vector<at::Tensor> inputTensors = {inputTensor};
std::vector<at::Tensor> outputTensors = {outputTensor};
SAVE_TENSORS(inputTensors, data->src);
SAVE_TENSORS(outputTensors, data->dst);
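  // Post the collective. The two empty lambdas are the pre-/post-processing
  // hooks, which all-to-all does not need.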
return collective_post(
OpType::ALLTOALL_BASE,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
inputTensor.device(),
outputTensors,
"ucc:alltoall");
}
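// A minimal usage sketch (illustrative only; `pg`, `in`, and `out` are
// hypothetical): an even all-to-all over a 4-rank group, where both tensors
// have a dim 0 divisible by 4 and live on the same device.
//
//   std::vector<int64_t> noSplits; // empty => even split, ALLTOALL path
//   auto work = pg->alltoall_base(out, in, noSplits, noSplits,
//                                 AllToAllOptions{});
//   work->wait();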