def load_hook()

in optimum/tpu/modeling_gemma.py


    def load_hook(self, state_dict, _prefix, *_args):
        # Pre-load hook: shard the checkpoint weights so that each tensor-parallel rank
        # keeps only its slice of the attention and MLP projection matrices.
        num_attn_heads = self.config.num_attention_heads
        head_dim = self.config.head_dim
        hidden_size = self.config.hidden_size

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            # Keep only this rank's contiguous slice of `tensor` along `axis`.
            axis_len = tensor.shape[axis]
            split_len = axis_len // self.world_size
            split_start = split_len * self.rank
            split_end = split_start + split_len
            tensor = torch.moveaxis(tensor, axis, 0)
            tensor = tensor[split_start:split_end, ...]
            tensor = torch.moveaxis(tensor, 0, axis)
            return tensor

        for k, v in state_dict.items():
            # Gate/up MLP projections are column-parallel: split along the output dim.
            if re.fullmatch(r"model.layers.\d+.mlp.(gate_proj|up_proj).weight", k):
                v = split(v, 0)
            # The down MLP projection is row-parallel: split along the input dim.
            if re.fullmatch(r"model.layers.\d+.mlp.down_proj.weight", k):
                v = split(v, 1)
            # Key/value projections: split along the output dim.
            if re.fullmatch(r"model.layers.\d+.self_attn.(k|v)_proj.weight", k):
                v = split(v, 0)
            # Query projection: split per attention head, then flatten back to 2D.
            if re.fullmatch(r"model.layers.\d+.self_attn.q_proj.weight", k):
                v = v.reshape(num_attn_heads, head_dim, hidden_size)
                v = split(v, 0)
                v = v.reshape(-1, hidden_size)
            # Output projection: split its input dim head-wise, then flatten back to 2D.
            if re.fullmatch(r"model.layers.\d+.self_attn.o_proj.weight", k):
                v = v.reshape(hidden_size, num_attn_heads, head_dim)
                v = split(v, 1)
                v = v.reshape(hidden_size, -1)
            # Update state_dict with the (possibly sharded) tensor.
            state_dict[k] = v
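
For illustration, below is a minimal, self-contained sketch of the same slicing logic outside the class, showing what each rank keeps for a toy gate_proj-style weight (split along dim 0), a down_proj-style weight (split along dim 1), and a q_proj-style weight (split per attention head). The shapes, world_size, and the shard() helper are hypothetical stand-ins, not part of the actual module:

    import torch

    def shard(tensor: torch.Tensor, axis: int, rank: int, world_size: int) -> torch.Tensor:
        # Same slicing as split() above, with rank/world_size passed explicitly.
        split_len = tensor.shape[axis] // world_size
        start = split_len * rank
        tensor = torch.moveaxis(tensor, axis, 0)
        tensor = tensor[start:start + split_len, ...]
        return torch.moveaxis(tensor, 0, axis)

    world_size = 2
    num_attn_heads, head_dim, hidden_size = 4, 4, 8               # hypothetical config values
    gate_proj = torch.randn(16, hidden_size)                      # (intermediate_size, hidden_size)
    down_proj = torch.randn(hidden_size, 16)                      # (hidden_size, intermediate_size)
    q_proj = torch.randn(num_attn_heads * head_dim, hidden_size)  # (num_heads * head_dim, hidden_size)

    for rank in range(world_size):
        print(shard(gate_proj, 0, rank, world_size).shape)   # torch.Size([8, 8])
        print(shard(down_proj, 1, rank, world_size).shape)   # torch.Size([8, 8])
        # q_proj is split per head so each rank keeps whole heads, then flattened back.
        v = q_proj.reshape(num_attn_heads, head_dim, hidden_size)
        v = shard(v, 0, rank, world_size)                    # 2 of the 4 heads per rank
        print(v.reshape(-1, hidden_size).shape)              # torch.Size([8, 8])

Splitting q_proj head-wise (and o_proj along its per-head input dimension) keeps each attention head intact on a single rank; in the full model this hook is presumably registered as a load_state_dict pre-hook so that the sharding happens while the checkpoint is being loaded.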