in optimum/tpu/modeling_gemma.py [0:0]
    def __init__(self, config, rank=None, world_size=None):
        super().__init__(config)
        # Fall back to the model-parallel runtime when rank/world size
        # are not supplied explicitly.
        self.rank = get_model_parallel_rank() if rank is None else rank
        self.world_size = get_model_parallel_world_size() if world_size is None else world_size
        # Use the resolved values, not the raw (possibly None) arguments.
        self.model = GemmaModel(config, self.rank, self.world_size)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Remap checkpoint keys before the state dict is loaded.
        self._register_load_state_dict_pre_hook(self.load_hook)
        # Initialize weights and apply final processing
        self.post_init()
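
A minimal usage sketch, under stated assumptions: that this `__init__` belongs to a `GemmaForCausalLM`-style class in this file and that the config comes from `transformers.GemmaConfig` (both names are assumptions about the surrounding code, not confirmed by this excerpt):

    from transformers import GemmaConfig

    config = GemmaConfig()
    # Passing rank/world_size explicitly skips the model-parallel lookups,
    # which is convenient for a single-process run; omitting them defers to
    # get_model_parallel_rank()/get_model_parallel_world_size() as above.
    model = GemmaForCausalLM(config, rank=0, world_size=1)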