def tree_attn_cpu()

in python/tvm/relax/frontend/nn/llm/tree_attn.py [0:0]


def tree_attn_cpu(h_kv: int, h_q: int, d: int, dtype: str, rope_scaling: Dict[str, Any]):
    """Generate a CPU TIR kernel for batched tree attention.

    The returned kernel processes the sequences of the batch one after another.
    For every query token it scores all key tokens of the same sequence. Pairs
    rejected by ``_check_tree_order`` receive a large negative score. The kernel
    then applies a base-2 softmax and writes the attention output together with
    the base-2 log-sum-exp of the scores.

    Parameters
    ----------
    h_kv : int
        Number of heads for key and value.
    h_q : int
        Number of heads for query. Query head ``h`` reads key/value head
        ``h // (h_q // h_kv)``. h_q is presumably a multiple of h_kv (TODO confirm).
    d : int
        Per-head dimension, i.e. the last axis of q, k, v and output.
    dtype : str
        Data type of the q, k, v and output buffers.
    rope_scaling : Dict[str, Any]
        RoPE scaling configuration. It is forwarded unchanged to ``_rope`` for
        both queries and keys.

    Returns
    -------
    func : tvm.tir.PrimFunc
        The generated ``batch_tree_attn`` kernel.
    """
    group_size = h_q // h_kv  # number of query heads sharing one key/value head

    # fmt: off
    @T.prim_func
    def batch_tree_attn(  # pylint: disable=too-many-branches,line-too-long
        var_q: T.handle,  # [total_len, h_q, d]
        var_q_indptr: T.handle,  # [batch_size + 1]
        var_k: T.handle,  # [total_len, h_kv, d]
        var_v: T.handle,  # [total_len, h_kv, d]
        var_kv_indptr: T.handle,  # [batch_size + 1], kv_indptr should be the same as q_indptr in this case
        var_q_rope_position: T.handle,  # [total_q_len]
        var_mn_indptr: T.handle,  # [batch_size + 1]
        var_mask: T.handle,  # [mn_indptr[batch_size]]
        var_output: T.handle,  # [total_len, h_q, d]
        var_lse: T.handle,  # [total_len, h_q]
        rotary_mode: T.int32,  # 1: apply RoPE to q and k inside the kernel; otherwise read them as-is
        rope_scale: T.float32,  # forwarded to _rope
        rope_theta: T.float32,  # forwarded to _rope
        sm_scale: T.float32,  # softmax scale applied to q.k (together with log2(e))
    ):
        # Symbolic extents/offsets, bound by the match_buffer declarations below.
        qo_len = T.int32(is_size_var=True)
        kv_len = T.int32(is_size_var=True)
        q_indptr_elem_offset = T.int32(is_size_var=True)
        kv_indptr_elem_offset = T.int32(is_size_var=True)
        q_rope_position_elem_offset = T.int32(is_size_var=True)
        mn_indptr_elem_offset = T.int32(is_size_var=True)
        mask_elem_offset = T.int32(is_size_var=True)
        tree_size = T.int32(is_size_var=True)
        batch_size_plus_1 = T.int32(is_size_var=True)

        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
        q_indptr = T.match_buffer(
            var_q_indptr, (batch_size_plus_1,), "int32", elem_offset=q_indptr_elem_offset
        )
        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
        kv_indptr = T.match_buffer(
            var_kv_indptr, (batch_size_plus_1,), "int32", elem_offset=kv_indptr_elem_offset
        )
        q_rope_position = T.match_buffer(
            var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset
        )
        mn_indptr = T.match_buffer(
            var_mn_indptr, (batch_size_plus_1,), "int32", elem_offset=mn_indptr_elem_offset
        )
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset)
        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # base-2 log-sum-exp per (token, head)

        # One sequence of the batch per iteration; sequences are processed serially.
        for b in T.serial(batch_size_plus_1 - 1):
            with T.block("attn"):
                # Scratch buffers reused for every query token of sequence b.
                softmax_sum = T.alloc_buffer([h_q], "float32")
                m_prev = T.alloc_buffer([h_q], "float32")
                m_new = T.alloc_buffer([h_q], "float32")
                d_prev = T.alloc_buffer([h_q], "float32")
                d_new = T.alloc_buffer([h_q], "float32")
                p_sum = T.alloc_buffer([d], "float32")

                max_score = T.alloc_buffer([h_q], "float32")
                attention_scores = T.alloc_buffer([kv_len, h_q], "float32")
                exp_scores = T.alloc_buffer([kv_len, h_q], "float32")
                attention_score = T.alloc_buffer(
                    [
                        1,
                    ],
                    "float32",
                )
                query_val = T.alloc_buffer(
                    [
                        1,
                    ],
                    "float32",
                )
                key_val = T.alloc_buffer(
                    [
                        1,
                    ],
                    "float32",
                )
                result = T.alloc_buffer(
                    [
                        1,
                    ],
                    "float32",
                )

                for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
                    # Reset the running maximum / denominator state for this query token.
                    for i in T.serial(h_q):
                        max_score[i] = -5e4
                        m_prev[i] = -5e4
                        d_prev[i] = 1.0

                    # Pass 1: masked, scaled scores (log2 domain) and their per-head maximum.
                    for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
                        for h in T.serial(h_q):
                            h_kv_idx = h // group_size

                            # Only (q_idx, k_idx) pairs accepted by _check_tree_order get a
                            # real dot-product score; presumably this encodes the tree's
                            # ancestor relation (TODO confirm against _check_tree_order).
                            if _check_tree_order(
                                row=q_idx,
                                col=k_idx,
                                batch=b,
                                tree_order=mask,
                                tree_order_indptr=mn_indptr,
                                kv_len=kv_indptr[b + 1] - kv_indptr[b],
                                qo_len=q_indptr[b + 1] - q_indptr[b],
                            ):
                                result[0] = 0.0
                                # Dot product over the head dimension. RoPE is applied
                                # element-wise on the fly when rotary_mode == 1.
                                for d_idx in T.serial(d):
                                    query_val[0] = T.if_then_else(
                                        rotary_mode == 1,
                                        _rope(
                                            q,
                                            q_rope_position[q_indptr[b] + q_idx],
                                            d,
                                            rope_theta,
                                            rope_scale,
                                            (q_indptr[b] + q_idx, h, d_idx),
                                            dtype,
                                            rope_scaling,
                                        ),
                                        q[q_indptr[b] + q_idx, h, d_idx],
                                    )

                                    # Key positions are read from q_rope_position at kv
                                    # offsets; this relies on kv_indptr == q_indptr.
                                    key_val[0] = T.if_then_else(
                                        rotary_mode == 1,
                                        _rope(
                                            k,
                                            q_rope_position[kv_indptr[b] + k_idx],
                                            d,
                                            rope_theta,
                                            rope_scale,
                                            (kv_indptr[b] + k_idx, h_kv_idx, d_idx),
                                            dtype,
                                            rope_scaling,
                                        ),
                                        k[kv_indptr[b] + k_idx, h_kv_idx, d_idx],
                                    )

                                    result[0] += query_val[0] * key_val[0]
                                # Multiplying by log2(e) lets the softmax below use exp2.
                                attention_score[0] = (
                                    result[0] * math.log2(math.exp(1)) * sm_scale
                                )
                            else:
                                # Masked pair: a large negative score, scaled the same way.
                                # It still enters pass 2 with weight exp2(score - m_new), so it
                                # is only effectively excluded when that weight is negligible.
                                attention_score[0] = -5e4 * math.log2(math.exp(1)) * sm_scale
                            attention_scores[k_idx, h] = attention_score[0]
                            max_score[h] = T.max(max_score[h], attention_score[0])
                            m_new[h] = T.max(m_prev[h], max_score[h])

                    # Rescale the initial denominator to the new maximum. With the initial
                    # state (m_prev=-5e4, d_prev=1.0) this term is exp2(-5e4 - m_new) <= 1,
                    # which is ~0 unless m_new is close to -5e4.
                    for h in T.serial(h_q):
                        d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])

                    # Pass 2: exponentiated scores and their per-head sum (softmax denominator).
                    for h in T.serial(h_q):
                        softmax_sum[h] = 0.0
                        for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
                            exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h])
                            softmax_sum[h] += exp_scores[k_idx, h]
                        d_new[h] += softmax_sum[h]
                    # NOTE(review): in TVMScript these assignments presumably rebind the
                    # names to the m_new/d_new buffers (no element copy), so the uses of
                    # m_prev/d_prev below (the lse store) read m_new/d_new — confirm.
                    d_prev = d_new
                    m_prev = m_new

                    # Pass 3: output = sum over keys of (softmax weight * v), then lse = m + log2(d).
                    for h in T.serial(h_q):
                        h_kv_idx = h // group_size
                        for i in T.serial(d):
                            p_sum[i] = 0.0
                        for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
                            weight = exp_scores[v_idx, h] / d_new[h]
                            for i in T.serial(d):
                                p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
                        for i in T.serial(d):
                            output[q_indptr[b] + q_idx, h, i] = p_sum[i]
                        lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h])

    # fmt: on
    # pylint: enable=line-too-long,too-many-branches
    return batch_tree_attn