def __init__()

in optimum/tpu/modeling_mistral.py


    def __init__(self, config: MistralConfig, rank: Optional[int] = None, world_size: Optional[int] = None):
        super().__init__(config)
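        # Default to the runtime's model-parallel rank and world size when
        # the caller does not supply them explicitly.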
        if rank is None:
            rank = get_model_parallel_rank()
        if world_size is None:
            world_size = get_model_parallel_world_size()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
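        # Each decoder layer receives the parallelism context (rank, world_size)
        # so its attention and MLP weights can be sharded across devices.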
        self.layers = nn.ModuleList(
            [MistralDecoderLayer(config, layer_idx, rank, world_size) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
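
For context, here is a minimal sketch of how this constructor is typically reached. The enclosing class name (MistralModel), the checkpoint id, and the explicit rank/world_size values are illustrative assumptions, not taken from the source; in practice both parallelism arguments can be omitted and default to the values reported by the model-parallel runtime.

    from transformers import MistralConfig

    from optimum.tpu.modeling_mistral import MistralModel  # assumed class name

    config = MistralConfig.from_pretrained("mistralai/Mistral-7B-v0.1")  # assumed checkpoint
    # Passing rank/world_size explicitly is optional; when omitted they are
    # resolved via get_model_parallel_rank()/get_model_parallel_world_size().
    model = MistralModel(config, rank=0, world_size=1)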