void ArmCpuDevice::runOneBatch()

in maga_transformer/cpp/devices/arm_impl/ArmAttentionOp.cc [189:406]


void ArmCpuDevice::runOneBatch(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) {
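    // Single-batch attention pipeline: optional QKV bias add, RoPE, QKV split,
    // KV-cache update, Q*K^T, scaled softmax (masked in the context phase),
    // attention*V, and a final transpose into the caller-provided output buffer.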
    auto datatype      = params.input.type();
    auto head_num      = params.configs.head_num;
    auto kv_head_num   = params.configs.kv_head_num;
    auto size_per_head = params.configs.size_per_head;

    // The kv cache may be absent here; it is only accessed below, guarded by
    // params.common.kv_cache.has_value().

    std::chrono::steady_clock::time_point tStart, tEnd;
    std::chrono::microseconds diff;
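    // Each stage below is wall-clock timed and reported through logTime().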

    // Fused QKV activations for this batch, addressed at its starting token offset (past_seq).
    auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head);
    printBufferData(params.input, "qkv");

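    // Stage: add the fused QKV bias in place when bias fusion is enabled.
    // FP32 activations accept an FP32 or FP16 bias; FP16 activations expect an FP16 bias.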
    tStart = std::chrono::steady_clock::now();
    if (params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias) {
        if (datatype == DataType::TYPE_FP32) {
            auto bias_data_type = params.weights.qkv_weight->bias->type();
            if (bias_data_type == DataType::TYPE_FP32) {
                addQKVBias<float, float>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
            } else if (bias_data_type == DataType::TYPE_FP16) {
                addQKVBias<float, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
            } else {
                throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
            }
        } else if (datatype == DataType::TYPE_FP16) {
            addQKVBias<__fp16, __fp16>(qkv, params.weights.qkv_weight->bias->data(), 1, seq_len, head_num, kv_head_num, size_per_head);
        } else {
            throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
        }
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 1);
    printBufferData(params.input, "biased qkv");

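    // Stage: apply rotary position embedding to Q and K. Only the Base RoPE
    // style (halfRopeQK) is implemented on this path.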
    tStart = std::chrono::steady_clock::now();
    if (params.configs.rope_config.style != RopeStyle::No) {
        if (params.configs.rope_config.style == RopeStyle::Base) {
            if (datatype == DataType::TYPE_FP32) {
                halfRopeQK<float>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
                                    params.configs.rope_config.base, params.configs.rope_config.dim);
            } else if (datatype == DataType::TYPE_FP16) {
                halfRopeQK<__fp16>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step,
                                    params.configs.rope_config.base, params.configs.rope_config.dim);
            } else {
                throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
            }
        } else {
            throw std::runtime_error("SelfAttention RoPE type is not supported");
        }
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 2);
    printBufferData(params.input, "roped qkv");

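    // Stage: split the fused QKV buffer along the head axis with ACL NESplit
    // into Q (head_num heads) and K, V (kv_head_num heads each), each backed
    // by a freshly allocated host buffer.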
    tStart = std::chrono::steady_clock::now();
    arm_compute::DataType acl_data_type = getAclDataType(datatype);
    arm_compute::NESplit    split;
    arm_compute::Tensor     src;
    arm_compute::TensorInfo src_info = arm_compute::TensorInfo(
        arm_compute::TensorShape(size_per_head, head_num + 2 * kv_head_num, seq_len, 1), 1, acl_data_type);

    src.allocator()->init(src_info);
    src.allocator()->import_memory(qkv);
    std::vector<arm_compute::Tensor>   dsts{};
    std::vector<arm_compute::ITensor*> dsts_ptr;

    arm_compute::TensorInfo q_info, kv_info;
    arm_compute::Tensor     q, k, v;
    q_info = arm_compute::TensorInfo(
        arm_compute::TensorShape(size_per_head, head_num, seq_len, 1), 1, acl_data_type);
    kv_info = arm_compute::TensorInfo(
        arm_compute::TensorShape(size_per_head, kv_head_num, seq_len, 1), 1, acl_data_type);

    auto q_input = allocateBuffer({datatype, {1, seq_len, head_num, size_per_head}, AllocationType::HOST}, {});
    auto k_input = allocateBuffer({datatype, {1, seq_len, kv_head_num, size_per_head}, AllocationType::HOST}, {});
    auto v_input = allocateBuffer({datatype, {1, seq_len, kv_head_num, size_per_head}, AllocationType::HOST}, {});

    q.allocator()->init(q_info);
    k.allocator()->init(kv_info);
    v.allocator()->init(kv_info);
    q.allocator()->import_memory(q_input->data());
    dsts.push_back(std::move(q));

    k.allocator()->import_memory(k_input->data());
    dsts.push_back(std::move(k));

    v.allocator()->import_memory(v_input->data());
    dsts.push_back(std::move(v));

    for (auto& dst : dsts) {
        dsts_ptr.emplace_back(&dst);
    }
    split.configure(&src, dsts_ptr, 1);
    split.run();
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 0);

    printBufferData(*q_input, "q_input");
    printBufferData(*k_input, "k_input");
    printBufferData(*v_input, "v_input");

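    // Stage: append this step's K/V to the KV cache, if a cache was provided.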
    tStart = std::chrono::steady_clock::now();
    if (params.common.kv_cache.has_value()) {
        updateKVCache(params, batch, step, k_input, v_input);
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 8);

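    // Context phase (step == 0): attend over the current sequence only.
    // Decode phase (step > 0): gather the full K/V history (step + 1 tokens)
    // from the cache via assemCache() before attending.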
    BufferPtr k_in, k_out, v_in, v_out;
    if (step == 0) {
        /* Context. */
        k_in = k_input;
        v_in = v_input;

        k_out = allocateBuffer({datatype, {1, kv_head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
        v_out = allocateBuffer({datatype, {1, kv_head_num, seq_len, size_per_head}, AllocationType::HOST}, {});
    } else {
        /* Decoder. Retrieve k/v cache data. */
        k_in = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
        v_in = allocateBuffer({datatype, {1, step + 1, kv_head_num, size_per_head}, AllocationType::HOST}, {});
        assemCache(params, batch, k_in, v_in);

        printBufferData(*k_in, "k_in");
        printBufferData(*v_in, "v_in");

        k_out = allocateBuffer({datatype, {1, kv_head_num, step + 1, size_per_head}, AllocationType::HOST}, {});
        v_out = allocateBuffer({datatype, {1, kv_head_num, step + 1, size_per_head}, AllocationType::HOST}, {});
    }

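    // Stage: swap the seq and head dimensions ([1, seq, head, dim] ->
    // [1, head, seq, dim]) for Q, K and V ahead of the per-head GEMMs.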
    tStart = std::chrono::steady_clock::now();
    auto q_output = allocateBuffer({datatype, {1, head_num, seq_len, size_per_head}, AllocationType::HOST}, {});

    transposeDim12(std::move(q_input), q_output->data());
    transposeDim12(std::move(k_in), k_out->data());
    transposeDim12(std::move(v_in), v_out->data());
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 3);

    printBufferData(*q_output, "q_output");
    printBufferData(*k_out, "k_out");
    printBufferData(*v_out, "v_out");

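    // Grouped-query attention: when the K/V head count differs from the query
    // head count, replicate each K/V head n_rep = head_num / kv_head_num times
    // so the GEMMs below see one K/V head per query head.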
    if (kv_head_num != head_num) {
        /* repeat K/V */
        size_t len;
        if (step == 0) {
            len = seq_len;
        } else {
            len = step + 1;
        }
        auto k_repeat = allocateBuffer({datatype, {1, head_num, len, size_per_head}, AllocationType::HOST}, {});
        auto v_repeat = allocateBuffer({datatype, {1, head_num, len, size_per_head}, AllocationType::HOST}, {});
        const size_t n_rep = head_num / kv_head_num;
        const int    N     = kv_head_num;
        parallel_for(N, [&](int tid) {
            for (size_t i = 0; i < n_rep; i++) {
                memcpy(k_repeat->dataWithOffset((tid * n_rep + i) * len * size_per_head),
                       k_out->dataWithOffset(tid * len * size_per_head),
                       k_out->sizeBytes() / kv_head_num);
                memcpy(v_repeat->dataWithOffset((tid * n_rep + i) * len * size_per_head),
                       v_out->dataWithOffset(tid * len * size_per_head),
                       v_out->sizeBytes() / kv_head_num);
            }
        });
        k_out = std::move(k_repeat);
        v_out = std::move(v_repeat);
    }

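    // Stage: attention scores = Q * K^T (ACL GEMM with K transposed).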
    tStart = std::chrono::steady_clock::now();
    auto qk_output = gemm_acl({*q_output,
                           *k_out,
                           std::nullopt,
                           nullptr,
                           DataType::TYPE_INVALID,
                           TransposeOperation::NONE,
                           TransposeOperation::TRANSPOSE});
    printBufferData(*qk_output, "qk_output: ");
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 4);

    printBufferData(*qk_output, "qk_output");

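    // Stage: softmax over the scores, scaled by 1/sqrt(size_per_head); the
    // context path additionally applies this batch's attention mask.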
    tStart = std::chrono::steady_clock::now();
    float scale = 1.0f / sqrtf(static_cast<float>(size_per_head));

    BufferPtr softmax_qk_output;
    if (seq_len == 1) {
        /* Decoder */
        softmax_qk_output = softmax({qk_output, std::nullopt, std::nullopt, scale});
    } else {
        /* Context */
        auto attention_mask = (*params.common.attention_mask).view(batch, 1);
        softmax_qk_output = softmax({qk_output, attention_mask, std::nullopt, scale});
    }
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 5);
    printBufferData(*softmax_qk_output, "softmax_qk_output");

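    // Stage: attention output = softmax(scores) * V.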
    tStart = std::chrono::steady_clock::now();
    auto qkv_output = gemm_acl({*softmax_qk_output, *v_out});
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 6);

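    // Stage: transpose back to [seq, head, dim] and write the result directly
    // at this batch's offset in the output buffer.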
    tStart = std::chrono::steady_clock::now();
    transposeDim12(qkv_output, params.output.dataWithOffset(past_seq * head_num * size_per_head));
    tEnd = std::chrono::steady_clock::now();
    diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart);
    logTime(diff, 7);
}