in optimum/neuron/models/inference/backend/modules/attention/attention_base.py [0:0]
def get_flash_attention_strategy(self, q_len) -> FlashAttentionStrategy:
    """
    Gets the flash attention strategy.

    For LNC1, use the unsharded kernel if the sequence length is at least 4096 to get the best
    performance. The unsharded kernel requires a sequence length of at least 512.
    For LNC2, use the sharded kernel if the sequence length is at least 1024 and divisible by
    1024. Otherwise, use no kernel, because the unsharded kernel performs worse than no kernel.
    These constraints may change later.

    TODO: Throw an exception instead of disabling flash attention if it is explicitly enabled
    but not eligible. This must take bucketing into account to avoid throwing an exception for
    smaller buckets.
    """
    if self.qk_scale is not None:
        # Flash attention is not supported when a custom qk_scale is provided.
        return FlashAttentionStrategy.NONE
    if int(self.logical_nc_config) > 1:
        # LNC2: only the sharded kernel is used, and it requires seq_len >= 1024
        # and seq_len divisible by 1024.
        if q_len < 1024:
            return FlashAttentionStrategy.NONE
        if q_len % 1024 == 0:
            return FlashAttentionStrategy.SHARDED_KERNEL
        warnings.warn(
            "Flash attention disabled. LNC2 requires seq_len % 1024 == 0 for flash attention to be performant."
        )
        return FlashAttentionStrategy.NONE
    # LNC1: if seq_len is at least 4096, enable flash attention automatically to improve performance.
    if q_len >= 4096:
        return FlashAttentionStrategy.UNSHARDED_KERNEL
    # At lower sequence lengths, enable the kernel only if it was explicitly requested.
    if self.attn_kernel_enabled and q_len >= 512:
        return FlashAttentionStrategy.UNSHARDED_KERNEL
    return FlashAttentionStrategy.NONE
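
For reference, below is a minimal, self-contained sketch of the same selection rules that can be run outside the class. The local FlashAttentionStrategy enum and the pick_strategy helper are illustrative stand-ins whose names, values, and signature are assumptions, not the library's actual definitions; only the branching mirrors the method above.

from enum import Enum


class FlashAttentionStrategy(Enum):
    # Illustrative stand-in for the real enum used by attention_base.py.
    NONE = "none"
    UNSHARDED_KERNEL = "unsharded_kernel"
    SHARDED_KERNEL = "sharded_kernel"


def pick_strategy(q_len, logical_nc_config, attn_kernel_enabled, qk_scale=None):
    # Hypothetical free-function mirror of get_flash_attention_strategy above.
    if qk_scale is not None:
        # A custom qk_scale disables flash attention entirely.
        return FlashAttentionStrategy.NONE
    if logical_nc_config > 1:
        # LNC2: sharded kernel only, and only for multiples of 1024 (>= 1024).
        if q_len >= 1024 and q_len % 1024 == 0:
            return FlashAttentionStrategy.SHARDED_KERNEL
        return FlashAttentionStrategy.NONE
    # LNC1: automatic at >= 4096, opt-in down to 512.
    if q_len >= 4096:
        return FlashAttentionStrategy.UNSHARDED_KERNEL
    if attn_kernel_enabled and q_len >= 512:
        return FlashAttentionStrategy.UNSHARDED_KERNEL
    return FlashAttentionStrategy.NONE


if __name__ == "__main__":
    cases = [
        (256, 1, True),    # too short even when opted in -> NONE
        (512, 1, True),    # opted in and >= 512 -> UNSHARDED_KERNEL
        (512, 1, False),   # not opted in -> NONE
        (4096, 1, False),  # long enough to enable automatically -> UNSHARDED_KERNEL
        (1024, 2, False),  # LNC2, multiple of 1024 -> SHARDED_KERNEL
        (1536, 2, False),  # LNC2, not a multiple of 1024 -> NONE
    ]
    for q_len, lnc, enabled in cases:
        print(f"q_len={q_len} LNC{lnc} enabled={enabled} -> {pick_strategy(q_len, lnc, enabled).name}")

Note that the LNC2 path never falls back to the unsharded kernel: per the docstring above, the unsharded kernel performs worse there than running without a kernel at all.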