optimum/tpu/modeling_llama.py
def __init__(self, config: LlamaConfig, rank: Optional[int] = None, world_size: Optional[int] = None):
    super().__init__(config)
    # Resolve the tensor-parallel rank and world size; fall back to the
    # values reported by the model-parallel runtime when not given explicitly.
    self.rank = rank if rank is not None else get_model_parallel_rank()
    self.world_size = world_size if world_size is not None else get_model_parallel_world_size()
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    # Propagate the resolved rank/world size so every decoder layer shards
    # its attention and MLP projections consistently.
    self.layers = nn.ModuleList(
        [
            LlamaDecoderLayer(config, layer_idx, self.rank, self.world_size)
            for layer_idx in range(config.num_hidden_layers)
        ]
    )
    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()
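
# --- Usage sketch (not part of the file above) ---
# A minimal sketch, assuming the enclosing class is this module's LlamaModel
# and that it is importable as shown below; the tiny LlamaConfig is
# illustrative, not a real checkpoint configuration.
from transformers import LlamaConfig

from optimum.tpu.modeling_llama import LlamaModel

config = LlamaConfig(
    vocab_size=1000,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    pad_token_id=0,
)

# Passing rank/world_size explicitly bypasses get_model_parallel_rank() and
# get_model_parallel_world_size(), so the model can be constructed outside an
# initialized model-parallel group (e.g. in a single-device unit test).
model = LlamaModel(config, rank=0, world_size=1)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the toy model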