AllGatherLoraLinearOutput CudaDevice::allGatherloraLinear()

in maga_transformer/cpp/devices/cuda_impl/CudaLoraLinear.cc [132:280]
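Summary (derived from the code below): when init_params_.overlap_comm_type is 2, no LoRA adapter input is attached, and tensor parallelism is active, this overload overlaps the tensor-parallel all-gather of the GEMM input with the GEMM itself: input chunks are ring-exchanged between ranks through pre-registered userbuffers while each chunk is multiplied on a side compute stream. In all other cases it falls back to DeviceBase::allGatherloraLinear.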


AllGatherLoraLinearOutput CudaDevice::allGatherloraLinear(const AllGatherLoraLinearParams& params) {
    size_t overlap_comm_type = init_params_.overlap_comm_type;
    if (overlap_comm_type <= 1) {
        return DeviceBase::allGatherloraLinear(params);
    }
    const LoraLinearParams& linear_params = params.lora_linear_params;
    const auto &gemm_a = linear_params.gemm_params.A;
    const auto &gemm_b = linear_params.gemm_params.B;
    const size_t m = gemm_a.shape()[0];
    const size_t k = gemm_a.shape()[1];
    const size_t n = gemm_b.shape()[1];

    BufferPtr output = nullptr;
    if (linear_params.gemm_params.D) {
        output = linear_params.gemm_params.D;
    } else {
        output = allocateBuffer({params.output_type, {m, n}});
        linear_params.gemm_params.D = output;
    }
    size_t tp_size = params.mode == ParallelMode::FFN_TP ? init_params_.ffn_tp_size : init_params_.tp_size;
    size_t tp_rank = params.mode == ParallelMode::FFN_TP ? init_params_.ffn_tp_rank : init_params_.tp_rank;

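    // Userbuffers overlap path: only taken when overlap_comm_type == 2, no LoRA adapter
    // input is attached, and tensor parallelism is enabled; otherwise fall through to the
    // base implementation at the bottom of the function.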
    if ((overlap_comm_type == 2) &&
        ((!linear_params.lora_input) || linear_params.lora_input->isEmpty()) &&
        init_params_.tp_size > 1) {
        CommBuffer* cb = nullptr;
        CommBuffer* scale_cb = nullptr;

        const auto m_chunk = m / tp_size;
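        // Select the userbuffers communicator that matches the parallel mode
        // (attention TP vs. FFN TP), plus the scale buffer for per-token int8.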
        if (params.mode == ParallelMode::TP) {
            cb = attn_ag_comm_buffer_.get();
            if (params.qscheme == Qint8PerToken) {
                scale_cb = attn_ag_scale_comm_buffer_.get();
            }
        } else if (params.mode == ParallelMode::FFN_TP) {
            cb = ffn_ag_comm_buffer_.get();
            if (params.qscheme == Qint8PerToken) {
                scale_cb = ffn_ag_scale_comm_buffer_.get();
            }
        } else {
            RTP_LLM_FAIL("unsupported parallel mode");
        }
        Communicator* comm = cb->_comm;

        // Per-chunk communication sizes: activation bytes, plus scale bytes for per-token int8
        const int comm_bytes = params.ag_send_buffer->sizeBytes();

        int scale_comm_bytes = 0;
        if (params.qscheme == Qint8PerToken) {
            scale_comm_bytes = reinterpret_cast<const QBuffer&>(*params.ag_send_buffer).scales().sizeBytes();
        }

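        // Fence the communication and compute side-streams on the work already queued
        // on the main stream before the ring exchange begins.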
        check_cuda_error(cudaEventRecord(cb->_start_compute, stream_));
        check_cuda_error(cudaStreamWaitEvent(cb->_stream_send[0], cb->_start_compute, 0));
        check_cuda_error(cudaStreamWaitEvent(cb->_stream_recv, cb->_start_compute, 0));
        for (size_t i = 0; i < cb->_stream_compute.size(); i++) {
            check_cuda_error(cudaStreamWaitEvent(cb->_stream_compute[i], cb->_start_compute, 0));
        }

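        // View the registered userbuffer (cb->_ubuf) as the full {m, k} gathered A matrix
        // (shadowing the earlier gemm_a reference); quantized schemes additionally wrap the
        // view in a QBuffer together with the activation scales.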
        BufferPtr gemm_a = nullptr;
        if (params.qscheme == NoQuantize) {
            gemm_a = BufferPtr(new Buffer(MemoryType::MEMORY_GPU, params.ag_send_buffer->type(), {m, k}, cb->_ubuf));
        } else if (params.qscheme == Qint8PerToken) {
            BufferPtr gemm_a_kernel = BufferPtr(new Buffer(MemoryType::MEMORY_GPU, params.ag_send_buffer->type(), {m, k}, cb->_ubuf));
            BufferPtr gemm_a_scale = BufferPtr(new Buffer(MemoryType::MEMORY_GPU, DataType::TYPE_FP32, {m}, scale_cb->_ubuf));
            gemm_a = BufferPtr(new QBuffer(std::move(gemm_a_kernel),
                                            std::move(gemm_a_scale),
                                            std::move(BufferPtr(
                                                new Buffer(MemoryType::MEMORY_GPU,
                                                DataType::TYPE_INVALID,
                                                {0},
                                                nullptr)))));
        } else if (params.qscheme == Qfp8PerTensor) {
            BufferPtr gemm_a_kernel = BufferPtr(new Buffer(MemoryType::MEMORY_GPU, params.ag_send_buffer->type(), {m, k}, cb->_ubuf));
            BufferPtr gemm_a_scale = BufferPtr(new Buffer(MemoryType::MEMORY_GPU, DataType::TYPE_FP32, {1}, reinterpret_cast<const QBuffer&>(linear_params.gemm_params.A).scalesData()));
            gemm_a = BufferPtr(new QBuffer(std::move(gemm_a_kernel),
                                            std::move(gemm_a_scale),
                                            std::move(BufferPtr(
                                                new Buffer(MemoryType::MEMORY_GPU,
                                                DataType::TYPE_INVALID,
                                                {0},
                                                nullptr)))));
        } else {
            RTP_LLM_FAIL("unsupported qscheme");
        }


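        // Ring all-gather overlapped with GEMM: in iteration i this rank runs the GEMM on
        // chunk (tp_rank - i) mod tp_size while forwarding that chunk to the next rank and
        // receiving the previous rank's chunk for the next iteration.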
        for (int i = 0; i < tp_size; i++) {
            int send_chunk_id = (tp_size + tp_rank - i) % tp_size;
            int send_offset = comm_bytes * send_chunk_id;

            BufferPtr input_a_chunk = nullptr;
            if (params.qscheme == NoQuantize) {
                input_a_chunk = gemm_a->slice(send_chunk_id * m_chunk, m_chunk);
            } else if (params.qscheme == Qint8PerToken) {
                input_a_chunk = reinterpret_cast<const QBuffer*>(gemm_a.get())->qslice(send_chunk_id * m_chunk, m_chunk);
            } else if (params.qscheme == Qfp8PerTensor) {
                input_a_chunk = reinterpret_cast<const QBuffer*>(gemm_a.get())->qslicePerTensor(send_chunk_id * m_chunk, m_chunk);
            } else {
                RTP_LLM_FAIL("unsupported qscheme");
            }
            BufferPtr output_chunk = output->slice(send_chunk_id * m_chunk, m_chunk);

            auto gemm_params = GemmParams(
                *input_a_chunk,
                linear_params.gemm_params.B,
                linear_params.gemm_params.C,
                output_chunk,
                linear_params.gemm_params.compute_type,
                linear_params.gemm_params.transA,
                linear_params.gemm_params.transB,
                linear_params.gemm_params.activationType,
                linear_params.gemm_params.alpha,
                linear_params.gemm_params.beta,
                init_params_.overlap_math_sm_count,
                cb->_stream_compute[i % cb->_stream_compute.size()]
            );
            loraLinear({gemm_params});

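            // No exchange is needed on the last iteration: every chunk has already been
            // received by every rank.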
            if (i < tp_size - 1) {
                userbuffers_send(cb->_ub_reg, send_offset, send_offset, comm_bytes, comm,
                                cb->_next_rank, cb->_stream_send[0]);
                userbuffers_recv(cb->_ub_reg, comm, cb->_prev_rank, cb->_stream_recv);

                if (params.qscheme == Qint8PerToken) {
                    userbuffers_send(scale_cb->_ub_reg, scale_comm_bytes * send_chunk_id, scale_comm_bytes * send_chunk_id, scale_comm_bytes, scale_cb->_comm,
                                    cb->_next_rank, cb->_stream_send[0]);
                    userbuffers_recv(scale_cb->_ub_reg, scale_cb->_comm, cb->_prev_rank, cb->_stream_recv);
                }
                check_cuda_error(cudaEventRecord(cb->_stop_recv, cb->_stream_recv));
                check_cuda_error(cudaStreamWaitEvent(cb->_stream_send[0], cb->_stop_recv, 0));
                check_cuda_error(
                    cudaStreamWaitEvent(cb->_stream_compute[(i + 1) % cb->_stream_compute.size()], cb->_stop_recv, 0));
            }
        }

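        // Join the compute, send, and recv streams back onto the main stream before returning.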
        for (size_t i = 0; i < cb->_stream_compute.size(); i++) {
            check_cuda_error(cudaEventRecord(cb->_stop_compute, cb->_stream_compute[i]));
            check_cuda_error(cudaStreamWaitEvent(stream_, cb->_stop_compute, 0));
        }
        check_cuda_error(cudaEventRecord(cb->_stop_send, cb->_stream_send[0]));
        check_cuda_error(cudaStreamWaitEvent(stream_, cb->_stop_send, 0));
        check_cuda_error(cudaEventRecord(cb->_stop_recv, cb->_stream_recv));
        check_cuda_error(cudaStreamWaitEvent(stream_, cb->_stop_recv, 0));

        return AllGatherLoraLinearOutput({std::move(output), std::move(gemm_a)});
    }
    return DeviceBase::allGatherloraLinear(params);
}
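For reference, a self-contained sketch of the chunk schedule that the loop above derives from send_chunk_id. The tp_size value and the printed output are illustrative only; this does not call any repository code.

#include <cstdio>

int main() {
    const int tp_size = 4;  // example world size, not taken from the code above
    for (int rank = 0; rank < tp_size; ++rank) {
        for (int i = 0; i < tp_size; ++i) {
            // Mirrors send_chunk_id = (tp_size + tp_rank - i) % tp_size in the loop above.
            int chunk = (tp_size + rank - i) % tp_size;
            std::printf("rank %d, iter %d: GEMM on chunk %d%s\n",
                        rank, i, chunk,
                        i < tp_size - 1 ? " (then forward it to the next rank)" : "");
        }
    }
    return 0;
}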