in maga_transformer/cpp/devices/cuda_impl/CudaOps.cc [446:584]
AllToAllOutput CudaDevice::allToAll(const AllToAllParams& params) {
    RTP_LLM_CHECK_WITH_INFO(params.mode == ParallelMode::DP_AND_TP,
                            "allToAll only supports ParallelMode::DP_AND_TP but got [%d]",
                            params.mode);
auto& nccl_param = dp_tp_nccl_param_;
const auto world_size = nccl_param.world_size_;
assert(params.buffers.size() > 0);
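    // A single-rank world has nothing to exchange; return the inputs as-is.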
if (world_size < 2) {
return {{params.buffers}};
}
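    // Run on the dedicated communication stream when overlap is requested and enabled,
    // otherwise on the main compute stream.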
auto stream = (params.overlapped && init_params_.enable_comm_overlap) ? communication_stream_ : stream_;
const size_t dims = params.buffers[0]->dim();
const auto batch_size = params.buffers[0]->shape()[0];
vector<BufferPtr> byte_buffers;
    RTP_LLM_CHECK_WITH_INFO(dims == 2 || dims == 1,
                            "allToAll only supports buffers with 1 or 2 dims but got [%s]",
                            params.buffers[0]->debugString().c_str());
size_t dim1_size = 0;
vector<size_t> dim1_split_size;
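    // Reinterpret every input as a 2-D byte buffer (folding the element size into dim 1)
    // so that buffers of different dtypes can be packed and exchanged in a single all-to-all.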
for (const auto& buffer : params.buffers) {
        RTP_LLM_CHECK_WITH_INFO(
            buffer->dim() == dims && buffer->shape()[0] == batch_size,
            "allToAll input buffers must all have consistent dims [%zu] and batch_size [%zu] but got [%s]",
            dims,
            batch_size,
            buffer->debugString().c_str());
vector<size_t> new_shape = buffer->shape();
if (new_shape.size() < 2) {
new_shape.push_back(1);
}
assert(new_shape.size() == 2);
new_shape[1] *= getTypeSize(buffer->type());
dim1_size += new_shape[1];
dim1_split_size.push_back(new_shape[1]);
byte_buffers.emplace_back(
std::make_shared<Buffer>(MemoryType::MEMORY_GPU, DataType::TYPE_BYTES, new_shape, buffer->data()));
}
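    // With more than one input, concatenate the byte views along dim 1 into one contiguous
    // [batch_size, dim1_size] buffer so a single NCCL call covers all of them.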
BufferPtr input_buffer;
if (byte_buffers.size() < 2) {
input_buffer = byte_buffers[0];
} else {
input_buffer = allocateBuffer({DataType::TYPE_BYTES, {batch_size, dim1_size}});
if (batch_size > 0) {
vector<torch::Tensor> input_tensors;
for (const auto& buffer : byte_buffers) {
input_tensors.emplace_back(Buffer2torchTensor(buffer, false));
}
torch::Tensor packed_tensor = Buffer2torchTensor(input_buffer, false);
torch::cat_out(packed_tensor, input_tensors, 1);
}
}
if (stream == communication_stream_) {
        // NOTE: before starting the communication we must make sure the preceding computation
        // has finished; otherwise the communication could race with the kernels producing its input.
        // A CUDA event recorded on the main stream makes the communication stream wait for that work.
cudaEvent_t event;
check_cuda_error(cudaEventCreate(&event));
check_cuda_error(cudaEventRecord(event, stream_));
check_cuda_error(cudaStreamWaitEvent(communication_stream_, event, 0));
check_cuda_error(cudaEventDestroy(event));
}
BufferPtr output;
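    // If split sizes are provided, ranks may exchange different numbers of rows (unequal split);
    // otherwise every rank exchanges the same number of rows via the equal-split path below.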
if (params.input_split_sizes.size() || params.output_split_sizes.size()) {
        RTP_LLM_CHECK_WITH_INFO(
            params.input_split_sizes.empty()
                || (params.input_split_sizes.size() == world_size
                    && std::accumulate(params.input_split_sizes.begin(), params.input_split_sizes.end(), (size_t)0)
                           == batch_size),
            "allToAll input_split_sizes must contain world_size entries that sum to batch_size");
if (params.output_split_sizes.empty()) {
output = allocateBufferLike(*input_buffer);
} else {
            RTP_LLM_CHECK_WITH_INFO(params.output_split_sizes.size() == world_size,
                                    "allToAll output_split_sizes must contain exactly world_size entries");
size_t output_batch_size =
std::accumulate(params.output_split_sizes.begin(), params.output_split_sizes.end(), (size_t)0);
auto new_shape = input_buffer->shape();
new_shape[0] = output_batch_size;
output = allocateBuffer({input_buffer->type(), new_shape});
}
std::vector<size_t> send_lengths(world_size);
std::vector<size_t> recv_lengths(world_size);
std::vector<size_t> send_offsets(world_size);
std::vector<size_t> recv_offsets(world_size);
computeLengthsAndOffsets(params.input_split_sizes, *input_buffer, &send_lengths, &send_offsets);
computeLengthsAndOffsets(params.output_split_sizes, *output, &recv_lengths, &recv_offsets);
all2all_single_unequal_split(input_buffer->data(),
send_lengths.data(),
send_offsets.data(),
output->data(),
recv_lengths.data(),
recv_offsets.data(),
getTypeSize(output->type()),
getNcclDataType(output->type()),
nccl_param.nccl_comm_,
stream);
} else {
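        // Equal split: the batch is partitioned evenly, batch_size / world_size rows per peer.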
        RTP_LLM_CHECK_WITH_INFO(input_buffer->shape()[0] % world_size == 0,
                                "all2all_single_equal_split batch size [%zu] must be divisible by world size [%d]",
                                input_buffer->shape()[0],
                                world_size);
output = allocateBufferLike(*input_buffer);
all2all_single_equal_split(
input_buffer->data(), output->data(), output->sizeBytes(), nccl_param.nccl_comm_, stream);
}
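    // Convert the received bytes back to the callers' dtypes: a single input is reshaped in place,
    // multiple inputs are split back into one output buffer per original input.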
AllToAllOutput all_to_all_output;
if (byte_buffers.size() < 2) {
vector<size_t> new_shape = output->shape();
new_shape[1] /= getTypeSize(params.buffers[0]->type());
output->updateTypeAndShape(params.buffers[0]->type(), new_shape);
all_to_all_output = {{output}};
} else {
vector<BufferPtr> outputs;
size_t output_batch_size = output->shape()[0];
if (output_batch_size == 0) {
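            // Nothing was received on this rank: return empty, correctly typed placeholder buffers.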
for (int i = 0; i < dim1_split_size.size(); ++i) {
vector<size_t> new_shape = params.buffers[i]->shape();
new_shape[0] = 0;
outputs.emplace_back(
std::make_shared<Buffer>(MemoryType::MEMORY_GPU, params.buffers[i]->type(), new_shape, nullptr));
}
} else {
outputs = split({*output, dim1_split_size, 1, params.overlapped}).outputs;
for (int i = 0; i < dim1_split_size.size(); ++i) {
vector<size_t> new_shape = outputs[i]->shape();
assert(new_shape[0] == output_batch_size);
new_shape[1] /= getTypeSize(params.buffers[i]->type());
assert(new_shape[1] == params.buffers[i]->shape()[1]);
outputs[i]->updateTypeAndShape(params.buffers[i]->type(), new_shape);
}
}
all_to_all_output = {{outputs}, input_buffer, output};
}
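    // For overlapped communication, attach a barrier hook so the caller can wait for the
    // transfer on the communication stream before consuming the outputs.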
if (params.overlapped) {
all_to_all_output.comm_barrier_hook = createCommHook();
}
return all_to_all_output;
}