def load_hook()

in local_gemma/attention.py [0:0]


    def load_hook(self, state_dict, prefix, *args):
        # State-dict pre-hook: rewrites checkpoints that store separate
        # q/k/v projection weights into the fused qkv_proj layout used here.
        if prefix + "q_proj.weight" in state_dict:
            q_proj = state_dict.pop(prefix + "q_proj.weight")
            k_proj = state_dict.pop(prefix + "k_proj.weight")
            v_proj = state_dict.pop(prefix + "v_proj.weight")
            # Fuse along the output dimension, in q, k, v order.
            state_dict[prefix + "qkv_proj.weight"] = torch.cat([q_proj, k_proj, v_proj])

            if self.config.attention_bias:
                # Biases are fused in the same q, k, v order.
                q_bias = state_dict.pop(prefix + "q_proj.bias")
                k_bias = state_dict.pop(prefix + "k_proj.bias")
                v_bias = state_dict.pop(prefix + "v_proj.bias")
                state_dict[prefix + "qkv_proj.bias"] = torch.cat([q_bias, k_bias, v_bias])
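For context, below is a minimal, self-contained sketch of how a hook like this is typically wired up and exercised. The module name, the `hidden_size` argument, the plain `attention_bias` attribute, and the registration via `nn.Module._register_load_state_dict_pre_hook` are illustrative assumptions, not code taken from local_gemma/attention.py; the real module presumably registers `self.load_hook` in its constructor so that legacy checkpoints are remapped transparently during `load_state_dict`.

```python
import torch
import torch.nn as nn


class FusedQKVAttention(nn.Module):
    """Toy attention stub with a fused qkv_proj (illustrative, not the real module)."""

    def __init__(self, hidden_size: int = 8, attention_bias: bool = False):
        super().__init__()
        self.attention_bias = attention_bias
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=attention_bias)
        # Run the remapping hook every time load_state_dict processes this module.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Same idea as the documented hook: pop the separate projections and
        # insert the fused key instead (bias keys would be handled analogously).
        if prefix + "q_proj.weight" in state_dict:
            q = state_dict.pop(prefix + "q_proj.weight")
            k = state_dict.pop(prefix + "k_proj.weight")
            v = state_dict.pop(prefix + "v_proj.weight")
            state_dict[prefix + "qkv_proj.weight"] = torch.cat([q, k, v])


# A checkpoint saved with separate projections loads cleanly into the fused module.
hidden = 8
legacy_checkpoint = {
    "q_proj.weight": torch.randn(hidden, hidden),
    "k_proj.weight": torch.randn(hidden, hidden),
    "v_proj.weight": torch.randn(hidden, hidden),
}
module = FusedQKVAttention(hidden)
module.load_state_dict(legacy_checkpoint)  # strict load succeeds after remapping
```

Fusing the three projections into one `qkv_proj` matmul is a common inference optimization; handling the conversion inside a load-state-dict pre-hook keeps checkpoints saved with the unfused layout loadable without any offline conversion step.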