in optimum/tpu/modeling_mistral.py
def __init__(self, config: MistralConfig, rank: Optional[int] = None, world_size: Optional[int] = None):
    super().__init__(config)
    # Default to the ambient model-parallel context when rank/world size are not given explicitly.
    if rank is None:
        rank = get_model_parallel_rank()
    if world_size is None:
        world_size = get_model_parallel_world_size()
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    # Each decoder layer receives the parallelism context so it can shard its weights accordingly.
    self.layers = nn.ModuleList(
        [MistralDecoderLayer(config, layer_idx, rank, world_size) for layer_idx in range(config.num_hidden_layers)]
    )
    self._attn_implementation = config._attn_implementation
    self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()
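
A minimal, self-contained sketch of the rank/world_size defaulting pattern used above. The helper functions and the ShardedStack class below are stand-ins, not optimum-tpu's real implementations (the real get_model_parallel_rank/get_model_parallel_world_size query the active XLA model-parallel group); the point is only that passing rank/world_size explicitly bypasses the ambient lookup, which is useful for tests or offline sharding.

from typing import Optional

# Hypothetical stand-ins for optimum-tpu's model-parallel helpers.
def get_model_parallel_rank() -> int:
    return 0

def get_model_parallel_world_size() -> int:
    return 1

class ShardedStack:
    """Toy module mirroring the constructor pattern above."""

    def __init__(self, rank: Optional[int] = None, world_size: Optional[int] = None):
        # Fall back to the ambient parallel context, as the Mistral model does.
        if rank is None:
            rank = get_model_parallel_rank()
        if world_size is None:
            world_size = get_model_parallel_world_size()
        self.rank, self.world_size = rank, world_size

print(ShardedStack().rank)                      # 0, from the ambient context
print(ShardedStack(rank=2, world_size=8).rank)  # 2, explicit placement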