in optimum/neuron/models/inference/backend/modules/attention/attention_base.py [0:0]
def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask) -> Tensor:
"""attention computation at prefilling (context encoding) phase"""
K_active = repeat_kv(K, self.num_key_value_groups)
V_active = repeat_kv(V, self.num_key_value_groups)
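    # K/V heads are expanded with repeat_kv so they match the number of query heads (grouped-query attention).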
    flash_attn_strategy = self.get_flash_attention_strategy(q_len)
    logger.debug(f"Flash attention strategy: {flash_attn_strategy}")
    if flash_attn_strategy != FlashAttentionStrategy.NONE:
        logger.debug(f"ATTN kernel: logical_nc_config={self.logical_nc_config}")
        # If we are using left padding, then bsz needs to be 1 (otherwise we get wrong results,
        # because flash attention does not use the attention_mask). In practice, we use right
        # padding, so this is unlikely to cause issues.
        assert self.padding_side == "right" or bsz == 1
        # The original shape of Q, K, V is BHSD, and the expected output is also BHSD.
        logger.debug(f"Using flash_fwd for Q.shape={Q.shape}")
        # Make sure to cast inputs to torch_dtype (this is needed because the downcast to bf16
        # might happen after the kernel HLO creation step). Also convert shapes as expected by the kernel.
        # Original Q shape: batch, num_heads, seqlen, d_head
        Q = (
            Q.permute(0, 1, 3, 2)  # after permute: batch, num_heads, d_head, seqlen
            .reshape((bsz * self.num_heads, self.head_dim, q_len))
            .to(self.torch_dtype)
        )
        Q = Q / math.sqrt(self.head_dim)
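        # Q is pre-scaled by 1/sqrt(head_dim) here, so the kernel below is invoked with a scale factor of 1.0.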
        K_active = (
            K_active.permute(0, 1, 3, 2).reshape((bsz * self.num_heads, self.head_dim, q_len)).to(self.torch_dtype)
        )
        V_active = V_active.reshape((bsz * self.num_heads, q_len, self.head_dim)).to(self.torch_dtype)
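        # After these reshapes, Q and K are laid out as (B*H, d_head, seqlen) and V as (B*H, seqlen, d_head), as expected by the kernel.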
        # shape: (B*H)DS
        attn_output = torch.zeros(bsz * self.num_heads, self.head_dim, q_len, dtype=Q.dtype, device=Q.device)
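        # The flash attention kernel writes its result into this preallocated buffer in place.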
logger.debug("Input parameter shapes")
logger.debug(f"Q input shape {Q.shape}")
logger.debug(f"K input shape {K_active.shape}")
logger.debug(f"V input shape {V_active.shape}")
logger.debug(f"Attn output shape {attn_output.shape}")
        if flash_attn_strategy == FlashAttentionStrategy.SHARDED_KERNEL:
            grid = (nc(self.logical_nc_config),)
            _flash_fwd_call[grid](
                Q,
                K_active,
                V_active,
                1.0,
                attn_output,
                kernel_name="CausalAttentionMMSoftmaxMMWithoutSwap",
            )
        elif flash_attn_strategy == FlashAttentionStrategy.UNSHARDED_KERNEL:
            _flash_fwd_call(
                Q,
                K_active,
                V_active,
                1.0,
                attn_output,
                kernel_name="CausalAttentionMMSoftmaxMMWithoutSwap",
            )
        else:
            raise ValueError(f"Invalid flash attention strategy: {flash_attn_strategy}")
        # shape: BHDS
        attn_output = attn_output.reshape((bsz, self.num_heads, self.head_dim, q_len))
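        # Note: the flash path leaves attn_output in BHDS layout (head_dim before seqlen), whereas the native path below produces BHSD; the caller presumably accounts for this difference.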
logger.debug(f"Attn output after reshape {attn_output.shape}")
    else:
        logger.debug("ATTN: native compiler")
        logger.debug(f"Not using flash_fwd for Q.shape={Q.shape}")
        active_scores = self.scaled_qk(Q, K_active, attention_mask)
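        # Compute the softmax in float32 for numerical stability, then cast back to the query dtype.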
        active_scores = nn.functional.softmax(active_scores, dim=-1, dtype=torch.float32).to(Q.dtype)
        attn_output = torch.matmul(active_scores, V_active)
    return attn_output, flash_attn_strategy