def __init__()

in optimum/tpu/modeling_llama.py


    def __init__(self, config, rank=None, world_size=None):
        super().__init__(config)
        # Fall back to the model-parallel runtime when rank/world_size are not given explicitly
        self.rank = rank if rank is not None else get_model_parallel_rank()
        self.world_size = world_size if world_size is not None else get_model_parallel_world_size()
        # Pass the resolved values so the inner model never receives None
        self.model = LlamaModel(config, self.rank, self.world_size)
        self.vocab_size = config.vocab_size
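        # The output projection is column-parallel: each rank holds a slice of the vocab dimension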
        self.lm_head = ColumnParallelLinear.create(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            rank=self.rank,
            world_size=self.world_size,
        )
        # Pre-process incoming state dicts with load_hook before they are loaded
        self._register_load_state_dict_pre_hook(self.load_hook)

        # Initialize weights and apply final processing
        self.post_init()
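
A minimal usage sketch, assuming the class is LlamaForCausalLM exported from optimum.tpu.modeling_llama and that a standard transformers LlamaConfig is accepted; the config sizes and the rank/world_size values below are illustrative, not taken from the source:

    from transformers import LlamaConfig
    from optimum.tpu.modeling_llama import LlamaForCausalLM

    # Small illustrative config; real checkpoints define their own sizes.
    config = LlamaConfig(hidden_size=512, num_hidden_layers=2, vocab_size=32000)

    # Explicit rank/world_size skip the get_model_parallel_* lookups;
    # passing None (the default) defers to the model-parallel runtime instead.
    model = LlamaForCausalLM(config, rank=0, world_size=2)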