def __init__()

in optimum/tpu/modeling_gemma.py


    def __init__(self, config, rank=None, world_size=None):
        super().__init__(config)
        # Resolve tensor-parallel placement, falling back to the current
        # model-parallel environment when rank/world_size are not given.
        self.rank = rank if rank is not None else get_model_parallel_rank()
        self.world_size = world_size if world_size is not None else get_model_parallel_world_size()
        # Pass the resolved values rather than the raw (possibly None) arguments.
        self.model = GemmaModel(config, self.rank, self.world_size)
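        # Keep the vocabulary size and build the LM head that projects
        # hidden states to per-token logits.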
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Run self.load_hook on incoming state dicts before load_state_dict consumes them.
        self._register_load_state_dict_pre_hook(self.load_hook)

        # Initialize weights and apply final processing
        self.post_init()
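
For orientation, a minimal construction sketch. It assumes this __init__ belongs to the causal-LM wrapper (called GemmaForCausalLM below; the snippet does not show the class name) and that the model-parallel process group is already initialized when rank/world_size are omitted; the config values are purely illustrative.

    from transformers import GemmaConfig

    config = GemmaConfig()  # default Gemma hyperparameters, illustrative only

    # Explicit placement: this process owns shard 0 of a 4-way model-parallel group.
    # GemmaForCausalLM is an assumed name for the class defined above.
    model = GemmaForCausalLM(config, rank=0, world_size=4)

    # Implicit placement: rank and world size are resolved from the current
    # model-parallel environment via get_model_parallel_rank() and
    # get_model_parallel_world_size().
    model = GemmaForCausalLM(config)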