in optimum/tpu/modeling_mistral.py [0:0]
def __init__(self, config, rank=None, world_size=None):
    """Build the causal-LM model with tensor-parallel sub-modules.

    Args:
        config: Mistral model configuration; provides ``hidden_size`` and
            ``vocab_size`` used below.
        rank: Tensor-parallel rank of this process. When ``None``, resolved
            via ``get_model_parallel_rank()``.
        world_size: Tensor-parallel world size. When ``None``, resolved via
            ``get_model_parallel_world_size()``.
    """
    super().__init__(config)
    # Resolve the parallelism parameters once, up front.
    self.rank = get_model_parallel_rank() if rank is None else rank
    self.world_size = (
        get_model_parallel_world_size() if world_size is None else world_size
    )
    # Forward the RESOLVED values, not the raw arguments: previously a None
    # rank/world_size was passed straight through to the sub-modules,
    # bypassing the resolution performed just above.
    self.model = MistralModel(config, self.rank, self.world_size)
    self.vocab_size = config.vocab_size
    self.lm_head = ColumnParallelLinear.create(
        config.hidden_size,
        config.vocab_size,
        bias=False,
        rank=self.rank,
        world_size=self.world_size,
    )
    # Run self.load_hook on every state_dict load, before keys are applied.
    self._register_load_state_dict_pre_hook(self.load_hook)
    # Initialize weights and apply final processing.
    self.post_init()