def __init__()

in optimum/tpu/modeling_mistral.py


    def __init__(
        self,
        config: MistralConfig,
        layer_idx: Optional[int] = None,
        rank: Optional[int] = None,
        world_size: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
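        # Grouped-query attention: each key/value head is shared by num_key_value_groups
        # query heads (e.g. Mistral-7B: hidden_size=4096, num_heads=32 -> head_dim=128;
        # num_key_value_heads=8 -> 4 query heads per KV head).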
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
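        # Tensor-parallel sharding (assumed Megatron-style split): q/k/v use
        # ColumnParallelLinear, so each rank holds a slice of the attention heads along
        # the output dimension; o_proj uses RowParallelLinear, splitting its input
        # dimension so the per-rank partial outputs can be combined across world_size ranks.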
        self.q_proj = ColumnParallelLinear.create(
            self.hidden_size,
            self.num_heads * self.head_dim,
            bias=False,
            world_size=world_size,
            rank=rank,
        )
        self.k_proj = ColumnParallelLinear.create(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            world_size=world_size,
            rank=rank,
        )
        self.v_proj = ColumnParallelLinear.create(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            world_size=world_size,
            rank=rank,
        )
        self.o_proj = RowParallelLinear.create(
            self.hidden_size,
            self.hidden_size,
            bias=False,
            world_size=world_size,
            rank=rank,
        )

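        # Rotary position embeddings (RoPE): sin/cos tables of size head_dim, built up to
        # max_position_embeddings with base rope_theta, applied to q and k in the forward pass.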
        self.rotary_emb = MistralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
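
A minimal usage sketch. It assumes the class shown above is MistralAttention (the attention module in optimum-tpu's modeling_mistral.py) and that rank and world_size would normally be supplied by the TPU runtime; the values below are illustrative placeholders rather than the library's actual wiring.

    from transformers import MistralConfig

    config = MistralConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
    # Shard the attention heads across 8 devices; this instance holds rank 0's slice.
    attn = MistralAttention(config, layer_idx=0, rank=0, world_size=8)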