in optimum/habana/transformers/models/llama/modeling_llama.py [0:0]
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
    super().__init__(config, layer_idx)

    # Module wrappers around the attention matmuls and the key/value caches
    self.matmul_qk = Matmul()
    self.matmul_av = Matmul()
    self.k_cache = KVCache()
    self.v_cache = KVCache()

    self.rotary_emb = GaudiLlamaRotaryEmbedding(config=config)

    self.num_key_value_heads = config.num_key_value_heads
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)

    # Optionally replace the separate q/k/v projections with a single fused linear layer
    if hasattr(config, "fused_qkv") and config.fused_qkv:
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.dim1 = self.num_heads * self.head_dim  # query projection width
        self.dim2 = config.num_key_value_heads * self.head_dim  # key/value projection width
        self.qkv_proj = torch.nn.Linear(
            self.hidden_size,
            self.dim1 + 2 * self.dim2,
            bias=config.attention_bias,
        )
        self.q_proj = None
        self.k_proj = None
        self.v_proj = None

    self.inp_seq_len = -1
    # Use Habana's fused scaled dot-product attention kernel when it is available
    self.fused_scaled_dot_product_attention = (
        ModuleFusedSDPA(
            FusedSDPA,
            scale=self.scaling,
            attention_dropout=self.attention_dropout,
            enable_recompute=False,
            flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
        )
        if FusedSDPA
        else None
    )
    # For all2all comm, distributed attention cares about the sequence (s) and number-of-heads (h)
    # dimensions; on HPU they are at indices 1 and 2.
    self.fused_scaled_dot_product_attention_distributed = None
    if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
        self.fused_scaled_dot_product_attention_distributed = (
            GaudiDistributedAttention(
                self.fused_scaled_dot_product_attention,
                scale=self.scaling,
                attention_dropout=self.attention_dropout,
                enable_recompute=False,
                flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
            )
            if FusedSDPA
            else None
        )
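
# Illustrative sketch (not part of modeling_llama.py): a minimal, self-contained
# example of how a fused QKV projection of width dim1 + 2 * dim2 can be split back
# into query/key/value tensors with the sizing set up above. All names and numbers
# below are hypothetical and exist only to show the grouped-query (GQA) shapes.
import torch

hidden_size, num_heads, num_kv_heads = 4096, 32, 8
head_dim = hidden_size // num_heads
dim1 = num_heads * head_dim               # query projection width
dim2 = num_kv_heads * head_dim            # key/value projection width (GQA)

qkv_proj = torch.nn.Linear(hidden_size, dim1 + 2 * dim2, bias=False)
hidden_states = torch.randn(1, 16, hidden_size)   # (batch, seq_len, hidden)

qkv = qkv_proj(hidden_states)
query, key, value = torch.split(qkv, [dim1, dim2, dim2], dim=-1)
# query: (1, 16, 4096); key, value: (1, 16, 1024)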