AllToAllOutput CudaDevice::allToAll()

in maga_transformer/cpp/devices/cuda_impl/CudaOps.cc [446:584]


AllToAllOutput CudaDevice::allToAll(const AllToAllParams& params) {
    RTP_LLM_CHECK_WITH_INFO(params.mode == ParallelMode::DP_AND_TP,
                       "all to all just support ParallelMode::DP_AND_TP but got [%d]",
                       params.mode);
    auto&      nccl_param = dp_tp_nccl_param_;
    const auto world_size = nccl_param.world_size_;
    assert(params.buffers.size() > 0);
    if (world_size < 2) {
        return {{params.buffers}};
    }
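    // Overlapped requests use the dedicated communication stream when comm overlap is enabled;
    // otherwise everything runs on the main stream.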
    auto         stream     = (params.overlapped && init_params_.enable_comm_overlap) ? communication_stream_ : stream_;
    const size_t dims       = params.buffers[0]->dim();
    const auto   batch_size = params.buffers[0]->shape()[0];
    vector<BufferPtr> byte_buffers;
    RTP_LLM_CHECK_WITH_INFO(dims == 2 || dims == 1,
                       "alltoall just support dims 2 or 1 but got [%s] ",
                       params.buffers[0]->debugString().c_str());
    size_t         dim1_size = 0;
    vector<size_t> dim1_split_size;
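    // Reinterpret each input buffer as a [batch_size, row_bytes] byte buffer so buffers of different
    // dtypes can be packed together; remember each buffer's row width in bytes to split the result later.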
    for (const auto& buffer : params.buffers) {
        RTP_LLM_CHECK_WITH_INFO(
            buffer->dim() == dims && buffer->shape()[0] == batch_size,
            "alltoall all input buffer dims must be consist with dims [%d] and batch_size [%d] but got [%s]",
            dims,
            batch_size,
            buffer->debugString().c_str());
        vector<size_t> new_shape = buffer->shape();
        if (new_shape.size() < 2) {
            new_shape.push_back(1);
        }
        assert(new_shape.size() == 2);
        new_shape[1] *= getTypeSize(buffer->type());
        dim1_size += new_shape[1];
        dim1_split_size.push_back(new_shape[1]);
        byte_buffers.emplace_back(
            std::make_shared<Buffer>(MemoryType::MEMORY_GPU, DataType::TYPE_BYTES, new_shape, buffer->data()));
    }
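    // Pack the byte views into one contiguous [batch_size, dim1_size] buffer by concatenating
    // along dim 1; a single input buffer is used as-is.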
    BufferPtr input_buffer;
    if (byte_buffers.size() < 2) {
        input_buffer = byte_buffers[0];
    } else {
        input_buffer = allocateBuffer({DataType::TYPE_BYTES, {batch_size, dim1_size}});
        if (batch_size > 0) {
            vector<torch::Tensor> input_tensors;
            for (const auto& buffer : byte_buffers) {
                input_tensors.emplace_back(Buffer2torchTensor(buffer, false));
            }
            torch::Tensor packed_tensor = Buffer2torchTensor(input_buffer, false);
            torch::cat_out(packed_tensor, input_tensors, 1);
        }
    }
    if (stream == communication_stream_) {
        // NOTE: before starting communication, make sure the preceding computation has finished;
        // otherwise the communication could read buffers that the main stream is still writing.
        // A CUDA event orders the communication stream after the main stream.
        cudaEvent_t event;
        check_cuda_error(cudaEventCreate(&event));
        check_cuda_error(cudaEventRecord(event, stream_));
        check_cuda_error(cudaStreamWaitEvent(communication_stream_, event, 0));
        check_cuda_error(cudaEventDestroy(event));
    }
    BufferPtr output;
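    // With explicit split sizes, ranks may exchange different numbers of rows (unequal split);
    // otherwise the batch is split equally across all ranks.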
    if (params.input_split_sizes.size() || params.output_split_sizes.size()) {
        RTP_LLM_CHECK_WITH_INFO(
            params.input_split_sizes.empty()
                || (params.input_split_sizes.size() == world_size
                    && std::accumulate(params.input_split_sizes.begin(), params.input_split_sizes.end(), (size_t)0)
                           == batch_size),
            "alltoall input_split_sizes is not valid");

        if (params.output_split_sizes.empty()) {
            output = allocateBufferLike(*input_buffer);
        } else {
            RTP_LLM_CHECK_WITH_INFO(params.output_split_sizes.size() == world_size,
                               "alltoall output_split_sizes is not valid");
            size_t output_batch_size =
                std::accumulate(params.output_split_sizes.begin(), params.output_split_sizes.end(), (size_t)0);
            auto new_shape = input_buffer->shape();
            new_shape[0]   = output_batch_size;
            output         = allocateBuffer({input_buffer->type(), new_shape});
        }
        std::vector<size_t> send_lengths(world_size);
        std::vector<size_t> recv_lengths(world_size);
        std::vector<size_t> send_offsets(world_size);
        std::vector<size_t> recv_offsets(world_size);
        computeLengthsAndOffsets(params.input_split_sizes, *input_buffer, &send_lengths, &send_offsets);
        computeLengthsAndOffsets(params.output_split_sizes, *output, &recv_lengths, &recv_offsets);
        all2all_single_unequal_split(input_buffer->data(),
                                     send_lengths.data(),
                                     send_offsets.data(),
                                     output->data(),
                                     recv_lengths.data(),
                                     recv_offsets.data(),
                                     getTypeSize(output->type()),
                                     getNcclDataType(output->type()),
                                     nccl_param.nccl_comm_,
                                     stream);
    } else {
        RTP_LLM_CHECK_WITH_INFO(input_buffer->shape()[0] % world_size == 0,
                           "all2all_single_equal_split batch size [%d] must divide world size [%d]",
                           input_buffer->shape()[0],
                           world_size);
        output = allocateBufferLike(*input_buffer);
        all2all_single_equal_split(
            input_buffer->data(), output->data(), output->sizeBytes(), nccl_param.nccl_comm_, stream);
    }
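    // Unpack: restore the original dtypes and per-row element counts on the received data; when
    // multiple inputs were packed, first split the packed output back along dim 1.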
    AllToAllOutput all_to_all_output;
    if (byte_buffers.size() < 2) {
        vector<size_t> new_shape = output->shape();
        new_shape[1] /= getTypeSize(params.buffers[0]->type());
        output->updateTypeAndShape(params.buffers[0]->type(), new_shape);
        all_to_all_output = {{output}};
    } else {
        vector<BufferPtr> outputs;
        size_t            output_batch_size = output->shape()[0];
        if (output_batch_size == 0) {
            for (int i = 0; i < dim1_split_size.size(); ++i) {
                vector<size_t> new_shape = params.buffers[i]->shape();
                new_shape[0]             = 0;
                outputs.emplace_back(
                    std::make_shared<Buffer>(MemoryType::MEMORY_GPU, params.buffers[i]->type(), new_shape, nullptr));
            }
        } else {
            outputs = split({*output, dim1_split_size, 1, params.overlapped}).outputs;
            for (int i = 0; i < dim1_split_size.size(); ++i) {
                vector<size_t> new_shape = outputs[i]->shape();
                assert(new_shape[0] == output_batch_size);
                new_shape[1] /= getTypeSize(params.buffers[i]->type());
                assert(new_shape[1] == params.buffers[i]->shape()[1]);
                outputs[i]->updateTypeAndShape(params.buffers[i]->type(), new_shape);
            }
        }
        all_to_all_output = {{outputs}, input_buffer, output};
    }
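    // For overlapped communication, attach a barrier hook so the caller can synchronize with the
    // communication stream before consuming the outputs.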
    if (params.overlapped) {
        all_to_all_output.comm_barrier_hook = createCommHook();
    }
    return all_to_all_output;
}
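
A minimal caller-side sketch of the unequal-split path, for reference. It relies only on the fields read inside allToAll() (mode, buffers, input_split_sizes, output_split_sizes, overlapped) and on the comm_barrier_hook member set above; the exact declarations of AllToAllParams and AllToAllOutput, the device handle, and the way the hook is waited on are assumptions made for illustration.

// Hypothetical usage sketch for a 4-rank DP_AND_TP group. Only the field names that
// appear in allToAll() above come from the source; everything else is assumed.
AllToAllParams params;
params.mode               = ParallelMode::DP_AND_TP;      // the only supported mode
params.buffers            = {hidden_states, token_lens};  // GPU buffers sharing the same batch dim (hypothetical)
params.input_split_sizes  = {4, 4, 2, 2};                 // world_size entries, summing to batch_size
params.output_split_sizes = {3, 3, 3, 3};                 // rows this rank receives from each peer
params.overlapped         = true;                         // exchange on the communication stream

auto result = device->allToAll(params);                   // device: CudaDevice* (assumed handle)
// With overlapped == true, wait on result.comm_barrier_hook (exact API assumed) before
// reading the outputs; each output keeps the dtype and per-row width of the matching
// input buffer and has sum(output_split_sizes) rows on this rank.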