AttentionLayerOutput DeviceBase::mlaAttentionLayer()

in maga_transformer/cpp/devices/base_impl/MlaAttentionLayer.cc [11:160]
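The base MLA (multi-head latent attention) layer implementation: it validates the compressed kv cache layout, runs the fused q/kv projections, dispatches decode tokens to the absorbed attention kernel and context tokens to either the absorbed or the regular context kernel, and finishes with the output projection (optionally applying LoRA).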


AttentionLayerOutput DeviceBase::mlaAttentionLayer(const AttentionLayerParams& params) {
    DevicePerfWrapper wrapper(this, "mla_layer_%d", params.layer_id);
    const auto& input            = params.input;
    const auto& input_lengths    = *params.common.input_lengths;


    const auto generate_batch_size = params.common.decoder_batch_size;
    const auto context_batch_size  = params.common.context_batch_size;
    const auto context_token_num   = params.common.context_token_num;
    const auto h_token_num         = context_token_num + generate_batch_size;

    RUNTIME_ASSERT_OP_ARG(!params.residual, "default attention layer impl does not support residual!");

    const auto& layer_kv_cache = params.common.kv_cache;
    if (layer_kv_cache.has_value()) {
        const auto& kv_cache          = layer_kv_cache.value();
        const auto& kv_cache_block_id = *kv_cache.kv_cache_block_id;
        const auto& shape             = kv_cache.kv_cache_block_id->shape();
        RUNTIME_ASSERT_OP_ARG(((shape.size() == 2) && (shape[0] == input_lengths.shape()[0])),
                            "kv_cache_block_id shape in attention layer should be [batch_size, block_length]"
                            ", but got %s",
                            kv_cache_block_id.debugString().c_str());
        RUNTIME_ASSERT_OP_ARG(kv_cache.k_cache_buffer && kv_cache.v_cache_buffer,
                            "kv cache buffers should have values when kv_cache_block_id is set");
        const auto& k_cache_shape = kv_cache.k_cache_buffer->shape();
        const auto& v_cache_shape = kv_cache.v_cache_buffer->shape();
        RUNTIME_ASSERT_OP_ARG(
            ((k_cache_shape.size() == 3) && (v_cache_shape.size() == 3) && (k_cache_shape[0] == v_cache_shape[0])
            && (k_cache_shape[1] == v_cache_shape[1]) && (k_cache_shape[1] == params.configs.tokens_per_block)
            && (k_cache_shape[2] == params.configs.kv_lora_rank + params.configs.rope_head_dim) && (v_cache_shape[2] == 0)),
            "mla kv cache buffer check shape failed. k_cache_buffer: %s, v_cache_buffer: %s",
            kv_cache.k_cache_buffer->debugString().c_str(),
            kv_cache.v_cache_buffer->debugString().c_str());
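        // Optional per-block quantization scales must line up with the kv cache block layout.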
        if (kv_cache.k_scale_buffer) {
            RUNTIME_ASSERT_OP_ARG(kv_cache.v_scale_buffer,
                                "v_scale_buffer should have a value when k_scale_buffer is set");
            const auto& k_scale_shape = kv_cache.k_scale_buffer->shape();
            const auto& v_scale_shape = kv_cache.v_scale_buffer->shape();
            RUNTIME_ASSERT_OP_ARG(((k_scale_shape.size() == 2) && (v_scale_shape.size() == 3)
                                && (k_scale_shape[0] == v_scale_shape[0]) && (k_scale_shape[1] == v_scale_shape[1])
                                && (k_cache_shape[0] == k_scale_shape[0])
                                && (k_scale_shape[1] == params.configs.tokens_per_block)),
                                "kv scale buffer shape check failed. k_scale_buffer: %s, v_scale_buffer: %s",
                                kv_cache.k_scale_buffer->debugString().c_str(),
                                kv_cache.v_scale_buffer->debugString().c_str());
        }
    }
    BufferPtr fused_qkv = nullptr;
    BufferPtr q = nullptr;
    int64_t kv_offset = 0;
    DevicePerfWrapper pre_mla_wrapper(this, "pre_mla_layer");
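    // Fused q/kv projection. With q-LoRA (fusedqkrope_weight) the low-rank q_a part is normalized with
    // q_a_norm_weight and up-projected through q_b_weight; without it the fused buffer already holds the
    // full q heads, which are sliced out below. kv_offset marks where the compressed kv section starts
    // inside fused_qkv.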
    if (params.weights.fusedqkrope_weight != nullptr) {
        fused_qkv        = gemm(GemmParams(input, *(params.weights.fusedqkrope_weight->kernel)));
        kv_offset        = params.configs.q_lora_rank;
        auto norm_output = layernormWithStride(
            LayernormWithStrideParams({fused_qkv,
                                       mayGetRef(params.weights.q_a_norm_weight),
                                       params.ln_params.eps,
                                       params.ln_params.norm_type,
                                       0,
                                       params.configs.q_lora_rank,
                                       QScheme::NoQuantize,
                                       false}));
        q                = gemm(GemmParams(*norm_output.output, *(params.weights.q_b_weight->kernel)));
    } else {
        fused_qkv = gemm(GemmParams(input, *(params.weights.fusedqkrope_no_lora_weight->kernel)));
        kv_offset = params.configs.head_num * params.configs.size_per_head;
        q         = slice(SliceParams({*fused_qkv, -1, 0, (int64_t)(params.configs.head_num * params.configs.size_per_head)}));
    }
    }
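    // Normalize the compressed kv section of fused_qkv (offset kv_offset, width kv_lora_rank). The return
    // value is unused, so the normalized values are presumably written back into fused_qkv in place.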
    layernormWithStride(
        LayernormWithStrideParams({fused_qkv,
                                   mayGetRef(params.weights.kv_a_norm_weight),
                                   params.ln_params.eps,
                                   params.ln_params.norm_type,
                                   (size_t)kv_offset,
                                   params.configs.kv_lora_rank,
                                   QScheme::NoQuantize,
                                   true}));
    pre_mla_wrapper.stop();
    auto dtype      = input.type();
    auto qkv_output = allocateBuffer({dtype, {h_token_num, params.configs.head_num * params.configs.v_head_dim}},
                                     {"qkv_output"});
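    // qkv_output holds the attention result for the whole batch: decode tokens occupy the first
    // generate_batch_size rows, context tokens the following context_token_num rows. Decode tokens
    // always go through the absorbed MLA kernel against the kv cache.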
    if (generate_batch_size) {
        RTP_LLM_LOG_DEBUG("absorb decode mla attention");
        RTP_LLM_CHECK_WITH_INFO(layer_kv_cache.has_value(), "kv cache can not be null for mla attention layer");
        auto generate_q = q->view(0, generate_batch_size);
        auto generate_fused_qkv = fused_qkv->view(0, generate_batch_size);
        auto generate_qkv_output = qkv_output->slice(0, generate_batch_size);
        mlaAbsorbAttention({params.layer_id,
                                 generate_q,
                                 generate_fused_qkv,
                                 kv_offset,
                                 generate_qkv_output,
                                 params.common,
                                 params.weights,
                                 params.configs,
                                 params.qscheme});
    }
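
    // Context tokens use the absorbed kernel only when a cached prefix is reused (max_prefix_length > 0);
    // otherwise they take the regular context attention path. In both branches the kv cache block id
    // table is first narrowed to the context rows of the batch.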

    if (context_batch_size) {
        bool use_absorb_attention = params.common.max_prefix_length > 0;
        if (use_absorb_attention) {
            RTP_LLM_LOG_DEBUG("absorb context mla attention");
            RTP_LLM_CHECK_WITH_INFO(layer_kv_cache.has_value(), "kv cache can not be null for mla attention layer");
            auto context_q = q->view(generate_batch_size, context_token_num);
            auto context_fused_qkv = fused_qkv->view(generate_batch_size, context_token_num);
            auto context_qkv_output = qkv_output->slice(generate_batch_size, context_token_num);
            if (layer_kv_cache) {
                auto layer_kv_cache_block_id = layer_kv_cache->kv_cache_block_id;
                params.common.kv_cache->kv_cache_block_id =
                    layer_kv_cache_block_id->slice(generate_batch_size, context_batch_size);
            }
            mlaAbsorbAttention({params.layer_id,
                                context_q,
                                context_fused_qkv,
                                kv_offset,
                                context_qkv_output,
                                params.common,
                                params.weights,
                                params.configs,
                                params.qscheme,
                                true});
        } else {
            RTP_LLM_LOG_DEBUG("no absorb context mla attention");
            // slice to get BufferPtr
            auto context_qkv_output = qkv_output->slice(generate_batch_size, context_token_num);
            auto context_fused_qkv = fused_qkv->slice(generate_batch_size, context_token_num);
            auto context_q = q->view(generate_batch_size, context_token_num);
            if (layer_kv_cache) {
                auto layer_kv_cache_block_id = layer_kv_cache->kv_cache_block_id;
                params.common.kv_cache->kv_cache_block_id =
                    layer_kv_cache_block_id->slice(generate_batch_size, context_batch_size);
            }
            mlaContextAttention({params.layer_id,
                                context_q,
                                *context_fused_qkv,
                                kv_offset,
                                context_qkv_output,
                                params.common,
                                params.weights,
                                params.configs,
                                params.qscheme});
        }
    }
  
    printBufferData(*qkv_output, "attent_proj_input");
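    // Output projection back to hidden size; loraLinear applies the attention-output LoRA, if any.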
    auto output_gemm_params = GemmParams(*qkv_output, *(params.weights.output_weight->kernel));
    auto attention_out = loraLinear(LoraLinearParams(output_gemm_params, params.common.lora_input.out_lora_input)).output;
    printBufferData(*attention_out, "attention_out");
    return {std::move(attention_out)};
}