void ArmCpuDevice::runOneBatchStride()

in maga_transformer/cpp/devices/arm_impl/ArmAttentionOp.cc [481:703]


void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) {
    auto datatype      = params.input.type();
    auto head_num      = params.configs.head_num;
    auto kv_head_num   = params.configs.kv_head_num;
    auto size_per_head = params.configs.size_per_head;
    std::chrono::steady_clock::time_point tStart, tEnd;
    std::chrono::microseconds diff;

    if (datatype != DataType::TYPE_FP32) {
        throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
    }
    // if (!params.common.kv_cache.has_value()) {
    //     throw std::runtime_error("kv cache block pointers can not be null");
    // }

    // Retrieve q/k/v by stride and not to split.
    auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head);
    printBufferData(params.input, "qkv");

    tStart = std::chrono::steady_clock::now();
    void *q_array[head_num], *k_array[head_num], *v_array[head_num];
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 0);

    tStart = std::chrono::steady_clock::now();
    //if (params.weights.qkv_weight->bias) {
    if (params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias) {
        auto bias_data_type = params.weights.qkv_weight->bias->type();
        if (bias_data_type == DataType::TYPE_FP32) {
            addQKVBias<float, float>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
        } else if (bias_data_type == DataType::TYPE_FP16) {
            addQKVBias<float, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
        } else {
            throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
        }
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 1);
    printBufferData(params.input, "biased qkv");

    tStart = std::chrono::steady_clock::now();
    if (params.configs.rope_config.style != RopeStyle::No) {
        if (params.configs.rope_config.style == RopeStyle::Base) {
            halfRopeQK<float>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
                                params.configs.rope_config.base,
                                params.configs.rope_config.dim);
        } else {
            throw std::runtime_error("SelfAttention RoPE type is not supported");
        }
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 2);
    printBufferData(params.input, "roped qkv");

    tStart = std::chrono::steady_clock::now();
    if (params.common.kv_cache.has_value()) {
        updateKVCacheStride<float>(params, qkv, batch, seq_len, step);
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 8);

    BufferPtr k_buffer, v_buffer;
    void *k_in, *v_in;
    int stride_kv;
    if (step == 0) {
        /* Context */
        k_in = (float *)qkv + head_num * size_per_head;
        v_in = (float *)qkv + (head_num + kv_head_num) * size_per_head;
        stride_kv = (head_num + 2 * kv_head_num) * size_per_head;

        step = seq_len - 1; // Trick to unify context and decoder processes.
    } else {
        /* Decoder */
        k_buffer = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
        v_buffer = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});

        assemCache(params, batch, k_buffer, v_buffer);
        printBufferData(*k_buffer, "k_buffer");
        printBufferData(*v_buffer, "v_buffer");
        k_in = k_buffer->data();
        v_in = v_buffer->data();
        stride_kv = kv_head_num * size_per_head;
    }

    tStart = std::chrono::steady_clock::now();

    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 3);

    tStart = std::chrono::steady_clock::now();
    /* Re-init arrary as per [batch, num_heads]. */
    getPerHeadArray<float>(qkv, k_in, v_in, q_array, k_array, v_array,
                            head_num, kv_head_num, size_per_head);

    auto qk_output = allocateBuffer({datatype, {1, head_num, seq_len, step + 1}, AllocationType::HOST}, {"qk_output"});
    const int stride_q = (head_num + 2 * kv_head_num) * size_per_head;
    //const int N = 1 * head_num;
    //parallel_for(N, [&](int tid) {
    const int MHA_HEADS = 1 * head_num;
    parallel_for(MHA_HEADS, [&](int tid) {
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, seq_len, step + 1, size_per_head, 1.0,
            (const float*)q_array[tid], stride_q,
            (const float*)k_array[tid], stride_kv, 0.0,
            (float*)qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1);
    });

    printBufferData(*qk_output, "qk_output");
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 4);

    tStart = std::chrono::steady_clock::now();
    float scale = (1.0f / sqrtf(size_per_head * 1.0f));

    BufferPtr softmax_qk_output;
    if (seq_len == 1) {
        /* Decoder */
        softmax_qk_output = softmax({qk_output, std::nullopt, std::nullopt, scale});
    } else {
        /* Context */
        RUNTIME_ASSERT_OP_ARG(params.common.attention_mask,
                          "attention_mask must be provided for default context attention implementation");
        auto attention_mask = (*params.common.attention_mask).view(batch, 1);
        printBufferData(attention_mask, "attention_mask");
        softmax_qk_output = softmax({qk_output, attention_mask, std::nullopt, scale});
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 5);

    printBufferData(*softmax_qk_output, "softmax_qk_output");

    tStart = std::chrono::steady_clock::now();
    //const int NN = 1 * head_num;
    //parallel_for(NN, [&](int tid) {
    //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
    //        (const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
    //        (const float*)v_array[tid], stride_kv, 0.0,
    //        (float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
    //});
    if (!isKAIenabled) {
        parallel_for(MHA_HEADS, [&](int tid) {
            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
                (const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
                (const float*)v_array[tid], stride_kv, 0.0,
                (float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
        });
    } else {
        if (seq_len == 1) {
            /* Decoder has higher performance with cblas for gemm. */
            parallel_for(MHA_HEADS, [&](int tid) {
                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, seq_len, size_per_head, step + 1, 1.0,
                    (const float*)softmax_qk_output->dataWithOffset(tid * seq_len * (step + 1)), step + 1,
                    (const float*)v_array[tid], stride_kv, 0.0,
                    (float*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), head_num * size_per_head);
            });
        } else {
            /* Context has higher performance with KleidiAI for gemm. */
            const size_t bias_size = size_per_head;
            float* bias = new float[bias_size];
            memset(bias, 0, bias_size * sizeof(float));
            const size_t M = seq_len;
            const size_t N = size_per_head;
            const size_t K = step + 1;
            const size_t mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
            const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
            const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
            const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
            const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(M, K, mr, kr, sr);
            const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(N, K, nr, kr);
            const size_t lhs_stride = K * sizeof(float);
            const size_t rhs_stride = stride_kv * sizeof(float);
            const size_t dst_stride_row = (head_num * size_per_head) * sizeof(float);
            const size_t dst_stride_col = sizeof(float);

            parallel_for(MHA_HEADS, [&](int tid) {
                uint8_t *lhs_packed = new uint8_t[lhs_packed_size];
                uint8_t *rhs_packed = new uint8_t[rhs_packed_size];
                kai_run_lhs_quant_pack_bf16p8x4_f32_neon(M, K, mr, kr, sr, 0,
                    (const void*)softmax_qk_output->dataWithOffset(tid * M * K), lhs_stride, lhs_packed);
                kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(1,
                    N,
                    K,
                    nr,
                    kr,
                    sr,
                    rhs_stride,
                    (const void*)v_array[tid],
                    bias,
                    NULL,
                    rhs_packed,
                    0,
                    NULL);
                kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(M, N, K,
                    lhs_packed,
                    rhs_packed,
                    (void*)params.output.dataWithOffset(past_seq * head_num * size_per_head + tid * size_per_head), dst_stride_row, dst_stride_col,
                    -FLT_MAX, FLT_MAX);
                delete[] rhs_packed;
                delete[] lhs_packed;
            });

            delete[] bias;
        }
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 6);

    tStart = std::chrono::steady_clock::now();
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 7);

    /* Print profile data at the end of operator unit test. */
    // if (a_cnt_[0] == 24)
    //     printStat();
}