optimum/tpu/modeling_gemma.py [1112:1138]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        hidden_size = self.config.hidden_size

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            # Keep only this rank's contiguous 1/world_size slice along `axis`.
            axis_len = tensor.shape[axis]
            split_len = axis_len // self.world_size
            split_start = split_len * self.rank
            split_end = split_start + split_len
            tensor = torch.moveaxis(tensor, axis, 0)
            tensor = tensor[split_start:split_end, ...]
            tensor = torch.moveaxis(tensor, 0, axis)
            return tensor

        for k, v in state_dict.items():
            # gate_proj/up_proj: column-parallel, shard the output features.
            if re.fullmatch(r"model.layers.\d+.mlp.(gate_proj|up_proj).weight", k):
                v = split(v, 0)
            # down_proj: row-parallel, shard the input features.
            if re.fullmatch(r"model.layers.\d+.mlp.down_proj.weight", k):
                v = split(v, 1)
            # k_proj/v_proj: column-parallel, shard the output features (KV heads).
            if re.fullmatch(r"model.layers.\d+.self_attn.(k|v)_proj.weight", k):
                v = split(v, 0)
            # q_proj: expose the head axis, shard whole query heads, flatten back.
            if re.fullmatch(r"model.layers.\d+.self_attn.q_proj.weight", k):
                v = v.reshape(num_attn_heads, head_dim, hidden_size)
                v = split(v, 0)
                v = v.reshape(-1, hidden_size)
            # o_proj: row-parallel over the head axis, matching the q/k/v shards.
            if re.fullmatch(r"model.layers.\d+.self_attn.o_proj.weight", k):
                v = v.reshape(hidden_size, num_attn_heads, head_dim)
                v = split(v, 1)
                v = v.reshape(hidden_size, -1)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
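
Both the MLP and attention weights are sharded Megatron-style: gate_proj/up_proj and k_proj/v_proj are split along their output features (column-parallel), down_proj and o_proj along their input features (row-parallel), and q_proj/o_proj are first reshaped so the split falls on whole attention heads. A minimal self-contained sketch of the same slicing, with rank and world_size passed explicitly instead of read off self (the toy shapes below are illustrative, not taken from either model):

import torch

def split(tensor: torch.Tensor, axis: int, rank: int, world_size: int) -> torch.Tensor:
    # Standalone stand-in for the method above: keep this rank's
    # contiguous 1/world_size slice of `tensor` along `axis`.
    split_len = tensor.shape[axis] // world_size
    start = split_len * rank
    tensor = torch.moveaxis(tensor, axis, 0)
    tensor = tensor[start:start + split_len, ...]
    return torch.moveaxis(tensor, 0, axis)

# Toy q_proj: 4 heads of dim 2, hidden_size 8 -> weight shape (4 * 2, 8).
num_attn_heads, head_dim, hidden_size, world_size = 4, 2, 8, 2
w = torch.arange(8 * 8, dtype=torch.float32).reshape(8, 8)

shards = []
for rank in range(world_size):
    shard = split(w.reshape(num_attn_heads, head_dim, hidden_size), 0, rank, world_size)
    shards.append(shard)
    print(rank, shard.reshape(-1, hidden_size).shape)  # each rank keeps 2 whole heads: (4, 8)

# Concatenating the per-rank shards along the head axis reproduces the full weight.
assert torch.equal(torch.cat(shards, dim=0).reshape(-1, hidden_size), w)
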



optimum/tpu/modeling_mistral.py [1193:1219]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        hidden_size = self.config.hidden_size

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            # Keep only this rank's contiguous 1/world_size slice along `axis`.
            axis_len = tensor.shape[axis]
            split_len = axis_len // self.world_size
            split_start = split_len * self.rank
            split_end = split_start + split_len
            tensor = torch.moveaxis(tensor, axis, 0)
            tensor = tensor[split_start:split_end, ...]
            tensor = torch.moveaxis(tensor, 0, axis)
            return tensor

        for k, v in state_dict.items():
            # gate_proj/up_proj: column-parallel, shard the output features.
            if re.fullmatch(r"model.layers.\d+.mlp.(gate_proj|up_proj).weight", k):
                v = split(v, 0)
            # down_proj: row-parallel, shard the input features.
            if re.fullmatch(r"model.layers.\d+.mlp.down_proj.weight", k):
                v = split(v, 1)
            # k_proj/v_proj: column-parallel, shard the output features (KV heads).
            if re.fullmatch(r"model.layers.\d+.self_attn.(k|v)_proj.weight", k):
                v = split(v, 0)
            # q_proj: expose the head axis, shard whole query heads, flatten back.
            if re.fullmatch(r"model.layers.\d+.self_attn.q_proj.weight", k):
                v = v.reshape(num_attn_heads, head_dim, hidden_size)
                v = split(v, 0)
                v = v.reshape(-1, hidden_size)
            # o_proj: row-parallel over the head axis, matching the q/k/v shards.
            if re.fullmatch(r"model.layers.\d+.self_attn.o_proj.weight", k):
                v = v.reshape(hidden_size, num_attn_heads, head_dim)
                v = split(v, 1)
                v = v.reshape(hidden_size, -1)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
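
The Mistral block is line-for-line identical to the Gemma one, so both models could delegate to a single shared helper. A sketch of what that might look like, assuming the caller supplies num_attn_heads, head_dim, hidden_size, rank, and world_size; the name shard_state_dict and its signature are hypothetical, not part of optimum-tpu:

import re
import torch

def shard_state_dict(
    state_dict: dict,
    num_attn_heads: int,
    head_dim: int,
    hidden_size: int,
    rank: int,
    world_size: int,
) -> dict:
    # Hypothetical shared helper: applies the same per-weight sharding rules
    # as the two duplicated loops above and returns a new dict for `rank`.
    def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
        split_len = tensor.shape[axis] // world_size
        start = split_len * rank
        tensor = torch.moveaxis(tensor, axis, 0)[start:start + split_len, ...]
        return torch.moveaxis(tensor, 0, axis)

    out = {}
    for k, v in state_dict.items():
        if re.fullmatch(r"model.layers.\d+.mlp.(gate_proj|up_proj).weight", k):
            v = split(v, 0)
        elif re.fullmatch(r"model.layers.\d+.mlp.down_proj.weight", k):
            v = split(v, 1)
        elif re.fullmatch(r"model.layers.\d+.self_attn.(k|v)_proj.weight", k):
            v = split(v, 0)
        elif re.fullmatch(r"model.layers.\d+.self_attn.q_proj.weight", k):
            v = split(v.reshape(num_attn_heads, head_dim, hidden_size), 0)
            v = v.reshape(-1, hidden_size)
        elif re.fullmatch(r"model.layers.\d+.self_attn.o_proj.weight", k):
            v = split(v.reshape(hidden_size, num_attn_heads, head_dim), 1)
            v = v.reshape(hidden_size, -1)
        out[k] = v
    return out
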



