def __init__()

in optimum/tpu/modeling_gemma.py [0:0]


    def __init__(self, config: GemmaConfig, rank: Optional[int] = None, world_size: Optional[int] = None):
        super().__init__(config)
        # Resolve the tensor-parallel coordinates, falling back to the values
        # reported by the model-parallel runtime when none are passed in.
        self.rank = rank if rank is not None else get_model_parallel_rank()
        self.world_size = world_size if world_size is not None else get_model_parallel_world_size()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Forward the resolved rank/world size (not the raw, possibly-None
        # constructor arguments) so every decoder layer shards against the
        # same topology.
        self.layers = nn.ModuleList(
            [
                GemmaDecoderLayer(config, layer_idx, self.rank, self.world_size)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
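
For orientation, a minimal instantiation sketch follows. It assumes the enclosing class is GemmaModel from optimum/tpu/modeling_gemma.py, and that passing rank and world_size explicitly skips the get_model_parallel_rank/get_model_parallel_world_size lookups, so a single-device run needs no initialized parallel runtime. The config values are illustrative only, not tied to any real checkpoint.

from transformers import GemmaConfig

# Tiny illustrative config; real checkpoints ship their own values.
config = GemmaConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=1,
)

# rank=0, world_size=1 models an unsharded single-device layout, so the
# runtime fallbacks in __init__ are never exercised.
model = GemmaModel(config, rank=0, world_size=1)
assert len(model.layers) == config.num_hidden_layers

Passing rank=0 and world_size=1 is a convenient way to exercise the constructor in tests without a sharded XLA environment; in an actual tensor-parallel run both arguments would be left as None so each process picks up its own coordinates from the runtime.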