optimum/tpu/modeling_llama.py [1188:1246]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        )
        # Shard checkpoint tensors for this rank before they are loaded into the model
        self._register_load_state_dict_pre_hook(self.load_hook)

        # Initialize weights and apply final processing
        self.post_init()

    def load_hook(self, state_dict, _prefix, *_args):
        """Replace full weights in ``state_dict`` with this rank's shard.

        Meant to be registered as a ``load_state_dict`` pre-hook. Matching
        entries of ``state_dict`` are overwritten in place; nothing is
        returned. Each matching weight is cut into ``self.world_size`` equal,
        contiguous chunks along one axis and only chunk ``self.rank`` is kept:

        * ``mlp.gate_proj`` / ``mlp.up_proj`` / ``self_attn.k_proj`` /
          ``self_attn.v_proj`` / ``lm_head``: axis 0.
        * ``mlp.down_proj``: axis 1.
        * ``self_attn.q_proj``: axis 0, after exposing the attention heads as
          a separate axis so that whole heads are assigned to a rank.
        * ``self_attn.o_proj``: axis 1, with the same per-head grouping.

        The chunk length is computed with floor division, so if an axis length
        is not a multiple of ``world_size`` the trailing remainder is dropped.

        Args:
            state_dict: Mapping from parameter name to tensor; modified in place.
            _prefix: Prefix under which this module's keys are stored in
                ``state_dict``. Keys are matched after removing it; keys that
                do not start with it are left untouched.
            *_args: Remaining hook arguments; ignored.
        """
        num_attn_heads = self.config.num_attention_heads
        hidden_size = self.config.hidden_size
        head_dim = hidden_size // num_attn_heads
        world_size = self.world_size
        rank = self.rank

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            # Keep the rank-th of world_size equal chunks along `axis`.
            split_len = tensor.shape[axis] // world_size
            split_start = split_len * rank
            index = [slice(None)] * tensor.dim()
            index[axis] = slice(split_start, split_start + split_len)
            return tensor[tuple(index)]

        def split_q_proj(weight: torch.Tensor) -> torch.Tensor:
            # Give heads their own axis so the cut falls on head boundaries.
            per_head = weight.reshape(num_attn_heads, head_dim, hidden_size)
            return split(per_head, 0).reshape(-1, hidden_size)

        def split_o_proj(weight: torch.Tensor) -> torch.Tensor:
            # Same per-head grouping as q_proj, but on the second axis.
            per_head = weight.reshape(hidden_size, num_attn_heads, head_dim)
            return split(per_head, 1).reshape(hidden_size, -1)

        # Dots are escaped so they only match the literal separator in key names.
        layer = r"model\.layers\.\d+\."
        rules = tuple(
            (re.compile(pattern), shard)
            for pattern, shard in (
                (layer + r"mlp\.(gate_proj|up_proj)\.weight", lambda w: split(w, 0)),
                (layer + r"mlp\.down_proj\.weight", lambda w: split(w, 1)),
                (layer + r"self_attn\.(k|v)_proj\.weight", lambda w: split(w, 0)),
                (layer + r"self_attn\.q_proj\.weight", split_q_proj),
                (layer + r"self_attn\.o_proj\.weight", split_o_proj),
                (r"lm_head\.weight", lambda w: split(w, 0)),
            )
        )

        prefix = _prefix or ""
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, weight in state_dict.items():
            if not key.startswith(prefix):
                continue
            name = key[len(prefix):]
            for pattern, shard in rules:
                if pattern.fullmatch(name):
                    # The rules are mutually exclusive; stop at the first hit.
                    state_dict[key] = shard(weight)
                    break

    def get_input_embeddings(self):
        """Return the ``embed_tokens`` attribute of the wrapped ``self.model``."""
        decoder = self.model
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Store ``value`` as the ``embed_tokens`` attribute of ``self.model``."""
        decoder = self.model
        decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the object stored in ``self.lm_head``."""
        output_layer = self.lm_head
        return output_layer

    def set_output_embeddings(self, new_embeddings):
        """Replace ``self.lm_head`` with ``new_embeddings``."""
        setattr(self, "lm_head", new_embeddings)

    def set_decoder(self, decoder):
        """Replace ``self.model`` with ``decoder``."""
        setattr(self, "model", decoder)

    def get_decoder(self):
        """Return the object stored in ``self.model``."""
        decoder = self.model
        return decoder
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



optimum/tpu/modeling_mistral.py [1183:1241]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        )
        # Shard checkpoint tensors for this rank before they are loaded into the model
        self._register_load_state_dict_pre_hook(self.load_hook)

        # Initialize weights and apply final processing
        self.post_init()

    def load_hook(self, state_dict, _prefix, *_args):
        """Replace full weights in ``state_dict`` with this rank's shard.

        Meant to be registered as a ``load_state_dict`` pre-hook. Matching
        entries of ``state_dict`` are overwritten in place; nothing is
        returned. Each matching weight is cut into ``self.world_size`` equal,
        contiguous chunks along one axis and only chunk ``self.rank`` is kept:

        * ``mlp.gate_proj`` / ``mlp.up_proj`` / ``self_attn.k_proj`` /
          ``self_attn.v_proj`` / ``lm_head``: axis 0.
        * ``mlp.down_proj``: axis 1.
        * ``self_attn.q_proj``: axis 0, after exposing the attention heads as
          a separate axis so that whole heads are assigned to a rank.
        * ``self_attn.o_proj``: axis 1, with the same per-head grouping.

        The chunk length is computed with floor division, so if an axis length
        is not a multiple of ``world_size`` the trailing remainder is dropped.

        Args:
            state_dict: Mapping from parameter name to tensor; modified in place.
            _prefix: Prefix under which this module's keys are stored in
                ``state_dict``. Keys are matched after removing it; keys that
                do not start with it are left untouched.
            *_args: Remaining hook arguments; ignored.
        """
        num_attn_heads = self.config.num_attention_heads
        hidden_size = self.config.hidden_size
        head_dim = hidden_size // num_attn_heads
        world_size = self.world_size
        rank = self.rank

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            # Keep the rank-th of world_size equal chunks along `axis`.
            split_len = tensor.shape[axis] // world_size
            split_start = split_len * rank
            index = [slice(None)] * tensor.dim()
            index[axis] = slice(split_start, split_start + split_len)
            return tensor[tuple(index)]

        def split_q_proj(weight: torch.Tensor) -> torch.Tensor:
            # Give heads their own axis so the cut falls on head boundaries.
            per_head = weight.reshape(num_attn_heads, head_dim, hidden_size)
            return split(per_head, 0).reshape(-1, hidden_size)

        def split_o_proj(weight: torch.Tensor) -> torch.Tensor:
            # Same per-head grouping as q_proj, but on the second axis.
            per_head = weight.reshape(hidden_size, num_attn_heads, head_dim)
            return split(per_head, 1).reshape(hidden_size, -1)

        # Dots are escaped so they only match the literal separator in key names.
        layer = r"model\.layers\.\d+\."
        rules = tuple(
            (re.compile(pattern), shard)
            for pattern, shard in (
                (layer + r"mlp\.(gate_proj|up_proj)\.weight", lambda w: split(w, 0)),
                (layer + r"mlp\.down_proj\.weight", lambda w: split(w, 1)),
                (layer + r"self_attn\.(k|v)_proj\.weight", lambda w: split(w, 0)),
                (layer + r"self_attn\.q_proj\.weight", split_q_proj),
                (layer + r"self_attn\.o_proj\.weight", split_o_proj),
                (r"lm_head\.weight", lambda w: split(w, 0)),
            )
        )

        prefix = _prefix or ""
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, weight in state_dict.items():
            if not key.startswith(prefix):
                continue
            name = key[len(prefix):]
            for pattern, shard in rules:
                if pattern.fullmatch(name):
                    # The rules are mutually exclusive; stop at the first hit.
                    state_dict[key] = shard(weight)
                    break

    def get_input_embeddings(self):
        """Return the ``embed_tokens`` attribute of the wrapped ``self.model``."""
        decoder = self.model
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Store ``value`` as the ``embed_tokens`` attribute of ``self.model``."""
        decoder = self.model
        decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the object stored in ``self.lm_head``."""
        output_layer = self.lm_head
        return output_layer

    def set_output_embeddings(self, new_embeddings):
        """Replace ``self.lm_head`` with ``new_embeddings``."""
        setattr(self, "lm_head", new_embeddings)

    def set_decoder(self, decoder):
        """Replace ``self.model`` with ``decoder``."""
        setattr(self, "model", decoder)

    def get_decoder(self):
        """Return the object stored in ``self.model``."""
        decoder = self.model
        return decoder
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



