in python/tvm/relax/frontend/nn/llm/tree_attn.py [0:0]
def tree_attn_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
    """Generate a CPU kernel for batched tree attention.

    Parameters
    ----------
    h_kv : int
        Number of heads for key and value.
    h_q : int
        Number of heads for query.
    d : int
        Head dimension.
    dtype : str
        Data type.
    rope_scaling : Dict[str, Any]
        The RoPE scaling configuration.

    Returns
    -------
    batch_tree_attn : tvm.tir.PrimFunc
        The generated TIR function for batched tree attention.
    """
group_size = h_q // h_kv
# fmt: off
@T.prim_func
def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long
var_q: T.handle, # [total_len, h_q, d]
var_q_indptr: T.handle, # [batch_size + 1]
var_k: T.handle, # [total_len, h_kv, d]
var_v: T.handle, # [total_len, h_kv, d]
var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case
var_q_rope_position: T.handle, # [total_q_len]
var_mn_indptr: T.handle, # [batch_size + 1]
var_mask: T.handle, # [mn_indptr[batch_size]]
var_output: T.handle, # [total_len, h_q, d]
var_lse: T.handle, # [total_len, h_q]
rotary_mode: T.int32,
rope_scale: T.float32,
rope_theta: T.float32,
sm_scale: T.float32,
):
qo_len = T.int32(is_size_var=True)
kv_len = T.int32(is_size_var=True)
q_indptr_elem_offset = T.int32(is_size_var=True)
kv_indptr_elem_offset = T.int32(is_size_var=True)
q_rope_position_elem_offset = T.int32(is_size_var=True)
mn_indptr_elem_offset = T.int32(is_size_var=True)
mask_elem_offset = T.int32(is_size_var=True)
tree_size = T.int32(is_size_var=True)
batch_size_plus_1 = T.int32(is_size_var=True)
q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
q_indptr = T.match_buffer(
var_q_indptr, (batch_size_plus_1,), "int32", elem_offset=q_indptr_elem_offset
)
k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
kv_indptr = T.match_buffer(
var_kv_indptr, (batch_size_plus_1,), "int32", elem_offset=kv_indptr_elem_offset
)
q_rope_position = T.match_buffer(
var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset
)
mn_indptr = T.match_buffer(
var_mn_indptr, (batch_size_plus_1,), "int32", elem_offset=mn_indptr_elem_offset
)
mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset)
output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable
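        # The loops below compute attention with an online-softmax update in
        # base 2: raw scores are pre-scaled by log2(e) * sm_scale so that
        # exp2/log2 can be used throughout. For each query row a running max m
        # and a running denominator d are maintained via
        #     m_new = max(m_prev, max_k s[k])
        #     d_new = d_prev * 2^(m_prev - m_new) + sum_k 2^(s[k] - m_new)
        # and the final output is (sum_k 2^(s[k] - m_new) * v[k]) / d_new with
        # lse = m_new + log2(d_new).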
for b in T.serial(batch_size_plus_1 - 1):
with T.block("attn"):
softmax_sum = T.alloc_buffer([h_q], "float32")
m_prev = T.alloc_buffer([h_q], "float32")
m_new = T.alloc_buffer([h_q], "float32")
d_prev = T.alloc_buffer([h_q], "float32")
d_new = T.alloc_buffer([h_q], "float32")
p_sum = T.alloc_buffer([d], "float32")
max_score = T.alloc_buffer([h_q], "float32")
attention_scores = T.alloc_buffer([kv_len, h_q], "float32")
exp_scores = T.alloc_buffer([kv_len, h_q], "float32")
                attention_score = T.alloc_buffer([1], "float32")
                query_val = T.alloc_buffer([1], "float32")
                key_val = T.alloc_buffer([1], "float32")
                result = T.alloc_buffer([1], "float32")
for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
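                    # Reset the per-head running softmax state; -5e4 serves as
                    # -inf in float32.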
for i in T.serial(h_q):
max_score[i] = -5e4
m_prev[i] = -5e4
d_prev[i] = 1.0
for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
for h in T.serial(h_q):
h_kv_idx = h // group_size
if _check_tree_order(
row=q_idx,
col=k_idx,
batch=b,
tree_order=mask,
tree_order_indptr=mn_indptr,
kv_len=kv_indptr[b + 1] - kv_indptr[b],
qo_len=q_indptr[b + 1] - q_indptr[b],
):
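                                # Key visible under the tree mask: take the q.k
                                # dot product, applying RoPE on the fly when
                                # rotary_mode is enabled.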
result[0] = 0.0
for d_idx in T.serial(d):
query_val[0] = T.if_then_else(
rotary_mode == 1,
_rope(
q,
q_rope_position[q_indptr[b] + q_idx],
d,
rope_theta,
rope_scale,
(q_indptr[b] + q_idx, h, d_idx),
dtype,
rope_scaling,
),
q[q_indptr[b] + q_idx, h, d_idx],
)
key_val[0] = T.if_then_else(
rotary_mode == 1,
_rope(
k,
q_rope_position[kv_indptr[b] + k_idx],
d,
rope_theta,
rope_scale,
(kv_indptr[b] + k_idx, h_kv_idx, d_idx),
dtype,
rope_scaling,
),
k[kv_indptr[b] + k_idx, h_kv_idx, d_idx],
)
result[0] += query_val[0] * key_val[0]
                                attention_score[0] = result[0] * math.log2(math.exp(1)) * sm_scale
else:
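                                    # Masked out by the tree order: a large
                                    # negative score vanishes in the softmax.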
attention_score[0] = -5e4 * math.log2(math.exp(1)) * sm_scale
attention_scores[k_idx, h] = attention_score[0]
max_score[h] = T.max(max_score[h], attention_score[0])
m_new[h] = T.max(m_prev[h], max_score[h])
for h in T.serial(h_q):
d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])
for h in T.serial(h_q):
softmax_sum[h] = 0.0
for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h])
softmax_sum[h] += exp_scores[k_idx, h]
d_new[h] += softmax_sum[h]
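                    # Rebind the running-state names so the lse computation at
                    # the end of this row reads the updated max and denominator.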
d_prev = d_new
m_prev = m_new
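                    # Second pass: weight each value row by its normalized
                    # probability and accumulate the output.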
for h in T.serial(h_q):
h_kv_idx = h // group_size
for i in T.serial(d):
p_sum[i] = 0.0
for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
weight = exp_scores[v_idx, h] / d_new[h]
for i in T.serial(d):
p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
for i in T.serial(d):
output[q_indptr[b] + q_idx, h, i] = p_sum[i]
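                        # Per-row log-sum-exp in base 2 (e.g. for merging
                        # partial attention results downstream).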
lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h])
# fmt: on
# pylint: enable=line-too-long,too-many-branches
return batch_tree_attn
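
# A minimal build sketch (hypothetical head counts, head dim, and dtype; elsewhere
# in this package the kernel is wired into the paged KV cache for tree/speculative
# decoding):
#
#     kernel = tree_attn_cpu(h_kv=4, h_q=32, d=128, dtype="float16", rope_scaling={})
#     mod = tvm.IRModule({"batch_tree_attn": kernel})
#     lib = tvm.build(mod, target="llvm")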