def __init__()

in optimum/habana/transformers/models/llama/modeling_llama.py [0:0]


    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)

        # Wrapper modules for the attention matmuls and the key/value caches, kept as
        # modules so they can be patched (e.g. for quantization) and so the caches can
        # be pre-allocated.
        self.matmul_qk = Matmul()
        self.matmul_av = Matmul()
        self.k_cache = KVCache()
        self.v_cache = KVCache()

        self.rotary_emb = GaudiLlamaRotaryEmbedding(config=config)
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)

        # Optionally fuse the q/k/v projections into a single qkv_proj linear layer.
        if hasattr(config, "fused_qkv") and config.fused_qkv:
            self.num_heads = config.num_attention_heads
            self.head_dim = config.hidden_size // self.num_heads
            self.dim1 = self.num_heads * self.head_dim  # width of the query slice
            self.dim2 = config.num_key_value_heads * self.head_dim  # width of each key/value slice
            self.qkv_proj = torch.nn.Linear(
                self.hidden_size,
                self.dim1 + 2 * self.dim2,
                bias=config.attention_bias,
            )
            self.q_proj = None
            self.k_proj = None
            self.v_proj = None
        self.inp_seq_len = -1
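        # Wrap Habana's FusedSDPA kernel when it was imported successfully; otherwise
        # this attribute stays None and the non-fused matmul attention path is used.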
        self.fused_scaled_dot_product_attention = (
            ModuleFusedSDPA(
                FusedSDPA,
                scale=self.scaling,
                attention_dropout=self.attention_dropout,
                enable_recompute=False,
                flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
            )
            if FusedSDPA
            else None
        )
        # For all-to-all communication, distributed attention cares about the sequence (s) and number-of-heads (h) dimensions; on HPU these are at indices 1 and 2.
        self.fused_scaled_dot_product_attention_distributed = None
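        # Only build the distributed attention wrapper when a sequence-parallel group
        # is initialized with a world size greater than 1.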
        if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
            self.fused_scaled_dot_product_attention_distributed = (
                GaudiDistributedAttention(
                    self.fused_scaled_dot_product_attention,
                    scale=self.scaling,
                    attention_dropout=self.attention_dropout,
                    enable_recompute=False,
                    flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
                )
                if FusedSDPA
                else None
            )
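

Minimal usage sketch (assumptions: this __init__ belongs to GaudiLlamaAttention, optimum-habana and its Habana runtime dependencies are installed, and the config values below are purely illustrative):

from transformers import LlamaConfig

from optimum.habana.transformers.models.llama.modeling_llama import GaudiLlamaAttention

# Illustrative grouped-query-attention config: 32 query heads, 8 key/value heads.
config = LlamaConfig(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
config.fused_qkv = True  # exercise the fused q/k/v projection branch shown above
attn = GaudiLlamaAttention(config, layer_idx=0)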