AttentionModuleOutput ROCmDevice::contextAttention()

in maga_transformer/cpp/devices/rocm_impl/ROCmAttentionOp.cc [77:305]
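
This is the context (prefill) attention path on ROCm. It stages Q/K/V into transposed buffers, optionally restores prefix-prompt KV from the paged cache, applies RoPE and the QKV bias via invokeAddFusedQKVBiasTranspose, then runs the Composable Kernel (CK) fused attention; if CK has no kernel for the shape, it falls back to an unfused gemm -> softmax -> gemm pipeline.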


AttentionModuleOutput ROCmDevice::contextAttention(const AttentionModuleParams& params) {
    auto datatype       = params.input.type();
    auto token_num      = params.input.shape()[0];
    auto batch_size     = params.common.context_batch_size;
    auto decoder_batch_size = params.common.decoder_batch_size;
    auto seq_len        = params.common.context_max_seq_len;
    auto seq_len_with_prefix = seq_len + params.common.max_prefix_length;
    // auto context_token_num   = params.common.context_token_num;
    auto head_num       = params.configs.head_num;
    auto kv_head_num    = params.configs.kv_head_num;
    auto size_per_head  = params.configs.size_per_head;
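
    // Staging buffers in transposed [batch, head, seq, dim] layout; K and V
    // are sized with max_prefix_length so reused prefix tokens fit in front.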

    auto q_output = allocateBuffer({params.input.type(),
                                    {batch_size, head_num, seq_len, size_per_head},
                                    AllocationType::DEVICE},
                                    {"q_output"});

    auto k_output = allocateBuffer({params.input.type(),
                                    {batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
                                    AllocationType::DEVICE},
                                    {"k_output"});

    auto v_output = allocateBuffer({params.input.type(),
                                    {batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
                                    AllocationType::DEVICE},
                                    {"v_output"});

    BufferPtr kv_cache_block_id = nullptr;

    KVBlockArray                  kv_block_array;
    PrefixPromptBatchWeightsParam prefix_prompt_param;

    if (params.common.kv_cache) {
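        // Build a paged KVBlockArray view over the cache; in the
        // {batch, 1, 2, max_blocks} id buffer, the inner 2 holds the
        // K and V block-id planes.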
        const auto max_blocks_per_batch = params.common.kv_cache->kv_cache_block_id->shape()[1];
        kv_cache_block_id = allocateBuffer({DataType::TYPE_INT32, {batch_size, 1, 2, max_blocks_per_batch}, AllocationType::DEVICE},
                                           {"kv_cache_block_id"});

        kv_block_array = getKVBlockArray(params, *kv_cache_block_id, batch_size, false, stream_);

        prefix_prompt_param.kv_block_array           = kv_block_array;

        if (params.common.prefix_prompt_lengths) {
            prefix_prompt_param.d_prefix_prompt_lengths  = params.common.prefix_prompt_lengths->data<int>();
            prefix_prompt_param.max_prefix_prompt_length = params.common.max_prefix_length;
            prefix_prompt_param.count_length             = 1;
        }
    }
    printBufferData(*params.common.input_lengths, "input_lengths");
    if (params.common.cu_seqlens) {
        printBufferData(*params.common.cu_seqlens, "cu_seqlens");
        printBufferData(*params.common.cu_kv_seqlens, "cu_kv_seqlens");
    }
    printBufferData(params.input, "fa_input");

    // int8 KV-cache quantization is not used on this path.
    float* scale_out_ptr = nullptr;
    int    int8_mode     = 0;

    if (prefix_prompt_param.max_prefix_prompt_length > 0) {
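        // Copy the cached prefix-prompt K/V into k_output/v_output so newly
        // computed tokens are appended after the reused prefix.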
        DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeLoadPrefixKVCache,
            q_output->data(),
            k_output->data(),
            v_output->data(),
            &prefix_prompt_param,
            batch_size,
            seq_len,
            head_num,
            kv_head_num,
            size_per_head,
            scale_out_ptr,
            int8_mode,
            stream_
        );
    }

    bool store_qkv = true;
    bool store_q = true;
    bool store_kv = true;
    bool store_cache = params.common.kv_cache.has_value();
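    // store_qkv/store_q/store_kv select which staging buffers the fused
    // kernel below writes; store_cache additionally scatters K/V into the
    // paged cache when one is attached.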

    // If all of these conditions hold, invokeAddFusedQKVBiasTranspose can be skipped entirely.
    bool skip_add_bias_transpose = (params.configs.rope_config.style == RopeStyle::No &&
                                    !params.common.kv_cache &&
                                    !params.configs.fuse_qkv_add_bias);
    RTP_LLM_LOG_DEBUG("skip_add_bias_transpose: %d", skip_add_bias_transpose);
    if (!skip_add_bias_transpose) {
        DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeAddFusedQKVBiasTranspose,
            q_output->data(),
            k_output->data(),
            v_output->data(),
            &prefix_prompt_param,
            params.input.data(),
            nullptr,
            params.common.position_ids ? params.common.position_ids->dataWithOffset<int>(decoder_batch_size * params.configs.rope_config.index_factor): nullptr,
            params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ? params.weights.qkv_weight->bias->data() : nullptr,
            params.common.padding_offset->data<int>(),
            params.common.cu_seqlens->data<int>(),
            batch_size,
            seq_len,
            token_num,
            head_num,
            kv_head_num,
            size_per_head,
            params.configs.rope_config,
            params.configs.use_logn_attn,
            scale_out_ptr,
            int8_mode,
            false,
            store_qkv,
            store_q,
            store_kv,
            store_cache,
            stream_
        );
        sync_check_cuda_error();
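        // writeCacheStore hands the newly written KV blocks to the cache
        // store, if one is configured.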
        
        writeCacheStore(params);
    }

    fmha_runner_->setup(datatype,
                        params.configs.mask_type,
                        head_num,
                        kv_head_num,
                        size_per_head,
                        params.configs.q_scaling);
    printBufferData(*q_output, "q_output");


    const size_t hidden_units    = head_num * size_per_head;
    const size_t hidden_units_kv = kv_head_num * size_per_head;
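
    // runCKFmha is called twice: this first call passes a null LSE
    // accumulation buffer and serves as a size query for the fp32 workspace;
    // the second call below launches the kernel for real.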

    uint32_t lse_acc_buf_sz =
        fmha_runner_->runCKFmha(params.input.data(),
                                params.input.dataWithOffset(hidden_units),
                                params.input.dataWithOffset(hidden_units + hidden_units_kv),
                                params.output.data(),
                                nullptr,  // buffer for storing softmax_lse; appears unused by RTP
                                batch_size,
                                seq_len,
                                // context_token_num,
                                params.common.cu_seqlens->data(),
                                params.common.cu_kv_seqlens->data(),
                                nullptr,
                                params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data() : nullptr,
                                nullptr);

    auto lse_acc_buf = allocateBuffer({DataType::TYPE_FP32, {lse_acc_buf_sz}, AllocationType::DEVICE}, {"lse_acc_buf"});
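
    // Real launch: a false return means CK has no suitable kernel for this
    // configuration, so fall back to the unfused implementation below.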

    if (fmha_runner_->runCKFmha(params.input.data(),
                                params.input.dataWithOffset(hidden_units),
                                params.input.dataWithOffset(hidden_units + hidden_units_kv),
                                params.output.data(),
                                nullptr,  // buffer for storing softmax_lse; appears unused by RTP
                                batch_size,
                                seq_len,
                                // context_token_num,
                                params.common.cu_seqlens->data(),
                                params.common.cu_kv_seqlens->data(),
                                lse_acc_buf->data(),
                                params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data() : nullptr,
                                nullptr)) {

        return;
    } else {
        RTP_LLM_LOG_WARNING("ck fmha failed, falling back to the default implementation. This decreases performance drastically.");
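        // Unfused reference path: QK^T (accumulated in fp32), masked scaled
        // softmax, then multiply by V and transpose back into params.output.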
        auto qk_output = gemm({*q_output,
                               *k_output,
                               std::nullopt,
                               nullptr,
                               DataType::TYPE_FP32,
                               TransposeOperation::NONE,
                               TransposeOperation::TRANSPOSE});
        printBufferData(*qk_output, "qk_output: ");

        float scale = (1.0f / sqrtf(size_per_head * 1.0f));

        // TODO(lidongjin): only supports float32 input / float16 output.
        auto softmax_type = qk_output->type();
        auto lengths_host = clone({
            params.common.input_lengths->view(decoder_batch_size, batch_size),
            AllocationType::HOST
        });
        auto prefix_lengths_host = params.common.prefix_prompt_lengths
                                 ? clone({*params.common.prefix_prompt_lengths, AllocationType::HOST})
                                 : BufferPtr(new Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_INVALID, {0}, nullptr));

        auto attention_mask = attentionMask({
            *lengths_host,
            *prefix_lengths_host,
            q_output->type(),
            params.configs.mask_type == AttentionMaskType::causalMask
        });
        RUNTIME_ASSERT_OP_ARG(
            params.common.attention_mask,
            "attention_mask must be provided for default context attention implementation");
        auto softmax_qk_output = softmax({std::move(qk_output),
                                          *attention_mask,
                                          std::nullopt,
                                          scale,
                                          datatype});
        printBufferData(*softmax_qk_output, "softmax_qk_output: ");

        auto qkv_output = gemm({*softmax_qk_output, *v_output});

        auto &qkv_transpose_output = params.output;

        DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeTransposeQKV,
            qkv_transpose_output.data(),
            qkv_output->data(),
            batch_size,
            seq_len,
            head_num,
            size_per_head,
            nullptr,
            0,
            stream_);
    }
}
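
For reference, the fallback branch computes plain scaled dot-product attention. Below is a minimal single-head CPU sketch of the same math (illustrative only; naive_attention is not part of this codebase). It mirrors the gemm -> softmax -> gemm sequence above; the CK kernel produces the same result without materializing the full seq_len x seq_len score matrix, which is why the fallback is logged as drastically slower.

#include <algorithm>
#include <cmath>
#include <vector>

// O = softmax(Q K^T / sqrt(d) + causal_mask) V for one head, with Q, K, V, O
// stored row-major as seq_len x d matrices.
void naive_attention(const float* Q, const float* K, const float* V,
                     float* O, int seq_len, int d, bool causal) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    std::vector<float> row(seq_len);
    for (int i = 0; i < seq_len; ++i) {
        // Scaled logits for query i, with causal masking; track the max
        // for a numerically stable softmax.
        float max_logit = -INFINITY;
        for (int j = 0; j < seq_len; ++j) {
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) dot += Q[i * d + k] * K[j * d + k];
            row[j] = (causal && j > i) ? -INFINITY : dot * scale;
            max_logit = std::max(max_logit, row[j]);
        }
        float denom = 0.0f;
        for (int j = 0; j < seq_len; ++j) {
            row[j] = std::exp(row[j] - max_logit);
            denom += row[j];
        }
        // Weighted sum over V gives output row i.
        for (int k = 0; k < d; ++k) {
            float acc = 0.0f;
            for (int j = 0; j < seq_len; ++j) acc += row[j] * V[j * d + k];
            O[i * d + k] = acc / denom;
        }
    }
}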